Path: blob/main/tests/schedulers/test_scheduler_dpm_multi.py
1448 views
import itertools
import tempfile

import torch

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

from .test_schedulers import SchedulerCommonTest


class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
    """Tests for ``DPMSolverMultistepScheduler``.

    Covers save/load config round-trips, step output shapes, the solver
    option matrix (order / type / algorithm / prediction type), full
    denoising loops against reference means, config interchange with the
    other DPM-family schedulers, and fp16 support.
    """

    scheduler_classes = (DPMSolverMultistepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overrides applied."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "prediction_type": "epsilon",
            "thresholding": False,
            "sample_max_value": 1.0,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
            "lower_order_final": False,
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Assert a scheduler reloaded via save_config/from_pretrained steps
        identically to the original, for the given config overrides."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            # step solver_order + 1 times so the multistep history is exercised
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        # Intentionally disabled: the common-test version does not seed the
        # multistep residual history; check_over_configs covers the round-trip.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Assert a reloaded scheduler's single step matches the original
        for the given forward kwargs (e.g. num_inference_steps)."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, scheduler=None, **config):
        """Run a full 10-step denoising loop with the dummy model and return
        the final sample. Builds a default scheduler unless one is passed."""
        if scheduler is None:
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        """step() must preserve the sample's shape across timesteps."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        # exercise the full thresholding matrix: order x solver type x
        # threshold x (the prediction types that support thresholding)
        for order, solver_type, threshold, prediction_type in itertools.product(
            [1, 2, 3],
            ["midpoint", "heun"],
            [0.5, 1.0, 2.0],
            ["epsilon", "sample"],
        ):
            self.check_over_configs(
                thresholding=True,
                prediction_type=prediction_type,
                sample_max_value=threshold,
                algorithm_type="dpmsolver++",
                solver_order=order,
                solver_type=solver_type,
            )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        # every combination of algorithm, solver type, order and prediction
        # type must round-trip configs and produce finite full-loop samples
        for algorithm_type, solver_type, order, prediction_type in itertools.product(
            ["dpmsolver", "dpmsolver++"],
            ["midpoint", "heun"],
            [1, 2, 3],
            ["epsilon", "sample"],
        ):
            self.check_over_configs(
                solver_order=order,
                solver_type=solver_type,
                prediction_type=prediction_type,
                algorithm_type=algorithm_type,
            )
            sample = self.full_loop(
                solver_order=order,
                solver_type=solver_type,
                prediction_type=prediction_type,
                algorithm_type=algorithm_type,
            )
            assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.3301) < 1e-3

    def test_full_loop_no_noise_thres(self):
        sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.6405) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2251) < 1e-3

    def test_switch(self):
        # make sure that iterating over schedulers with same config names gives same results
        # for defaults
        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.3301) < 1e-3

        # round-trip the config through the sibling DPM-family schedulers and
        # back; the reference mean must be unchanged
        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)

        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.3301) < 1e-3

    def test_fp16_support(self):
        """A full half-precision loop must keep the sample in float16."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16