Path: blob/main/tests/schedulers/test_scheduler_deis.py
1448 views
import tempfile

import torch

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

from .test_schedulers import SchedulerCommonTest


class DEISMultistepSchedulerTest(SchedulerCommonTest):
    """Tests for `DEISMultistepScheduler`: config/save-load round-trips,
    step output shapes, solver-order/prediction-type combinations, and
    full denoising loops with known expected result means."""

    scheduler_classes = (DEISMultistepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        """Return a default scheduler config dict, overridden by `kwargs`."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Save a scheduler configured with `config`, reload it, and assert
        both produce identical `step` outputs over `solver_order + 1` steps."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        # Intentionally disabled: the common-test version is not applicable here;
        # save/load equivalence is covered by check_over_configs/check_over_forward.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Save a default-config scheduler, reload it, and assert both produce
        an identical single `step` output at `time_step`."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, scheduler=None, **config):
        """Run a full 10-step denoising loop and return the final sample.

        Uses the supplied `scheduler` when given; otherwise builds a fresh
        default scheduler from `config`.

        Bug fix: the original unconditionally rebuilt the scheduler after the
        `if scheduler is None` guard, discarding any caller-supplied scheduler
        and making `test_switch`'s `from_config` round-trip a no-op.
        """
        if scheduler is None:
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        """`step` must preserve the sample's shape across consecutive timesteps."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_switch(self):
        # make sure that iterating over schedulers with same config names gives same results
        # for defaults
        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.23916) < 1e-3

        # round-trip the config through the sibling multistep schedulers and back
        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        scheduler = DEISMultistepScheduler.from_config(scheduler.config)

        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.23916) < 1e-3

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["logrho"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            algorithm_type="deis",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        for algorithm_type in ["deis"]:
            for solver_type in ["logrho"]:
                for order in [1, 2, 3]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.23916) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.091) < 1e-3

    def test_fp16_support(self):
        """A full loop on a half-precision sample must keep the fp16 dtype."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16