# Path: blob/main/tests/schedulers/test_scheduler_dpm_single.py
import tempfile

import torch

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

from .test_schedulers import SchedulerCommonTest


class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
    """Tests for ``DPMSolverSinglestepScheduler``.

    Covers config save/load round-trips, solver-order/type/threshold
    combinations, full denoising loops against reference output means,
    and fp16 support.
    """

    scheduler_classes = (DPMSolverSinglestepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, updated with ``kwargs``."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "prediction_type": "epsilon",
            "thresholding": False,
            "sample_max_value": 1.0,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Check that a scheduler and its save/load round-trip produce
        identical step outputs for the given config overrides."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            # step both schedulers over enough timesteps to exercise the
            # full solver order and compare outputs
            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        # covered by check_over_configs / check_over_forward instead
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Check save/load round-trip equality for a single step with the
        given forward kwargs (e.g. ``num_inference_steps``)."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, scheduler=None, **config):
        """Run a 10-step denoising loop and return the final sample.

        If ``scheduler`` is None, one is built from the default config
        updated with ``config``; otherwise the provided scheduler is used.
        (Previously a duplicated construction unconditionally replaced the
        passed-in scheduler, making ``test_switch`` a no-op.)
        """
        if scheduler is None:
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_switch(self):
        # make sure that iterating over schedulers with same config names
        # gives same results for defaults
        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

        # round-trip the config through the other multistep schedulers and
        # back; the final scheduler must reproduce the reference mean
        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)

        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["midpoint", "heun"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            algorithm_type="dpmsolver++",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
            for solver_type in ["midpoint", "heun"]:
                for order in [1, 2, 3]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.1453) < 1e-3

    def test_fp16_support(self):
        scheduler_class = self.scheduler_classes[0]
        # dynamic_thresholding_ratio=0 keeps thresholding numerically benign
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        # the output dtype must stay fp16 through the whole loop
        assert sample.dtype == torch.float16