# Path: blob/main/tests/schedulers/test_scheduler_ipndm.py
# (source page reported 1448 views)
import tempfile

import torch

from diffusers import IPNDMScheduler

from .test_schedulers import SchedulerCommonTest


class IPNDMSchedulerTest(SchedulerCommonTest):
    """Test suite for `IPNDMScheduler`.

    Covers config save/load round-trips (`check_over_configs` /
    `check_over_forward`), output-shape invariants of `step`
    (`test_step_shape`), and a deterministic full denoising loop
    (`test_full_loop_no_noise`).
    """

    scheduler_classes = (IPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config with overrides from **kwargs applied."""
        config = {"num_train_timesteps": 1000}
        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Assert a scheduler and its save/load round-trip produce identical `step` outputs.

        Args:
            time_step: timestep passed to `step`; `None` selects the middle of
                the schedule after `set_timesteps`.
            **config: overrides forwarded to `get_scheduler_config`.
        """
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.ets = dummy_past_residuals[:]

            if time_step is None:
                time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            # step() is stateful (it appends to `ets`), so a second call checks
            # that both schedulers stay in sync beyond the first step
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        # The common-suite version of this test does not apply to IPNDM;
        # round-trip coverage lives in check_over_configs / check_over_forward.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Assert identical `step` outputs across a save/load round-trip for given forward kwargs.

        Args:
            time_step: timestep passed to `step`; `None` selects the middle of
                the schedule after `set_timesteps`.
            **forward_kwargs: forward-call overrides (e.g. `num_inference_steps`).
        """
        kwargs = dict(self.forward_default_kwargs)
        # BUGFIX: forward_kwargs was previously ignored, so callers such as
        # test_inference_steps never actually varied num_inference_steps.
        # Merging here matches the sibling PNDM scheduler test.
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.ets = dummy_past_residuals[:]

            if time_step is None:
                time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.ets = dummy_past_residuals[:]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            # second stateful step — see check_over_configs
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        """Run a deterministic 10-step denoising loop twice and return the final sample.

        NOTE(review): the timestep loop is intentionally executed twice over the
        same schedule — IPNDM's step is stateful (`ets`), so the second pass runs
        with a warm residual history. The expected mean asserted in
        test_full_loop_no_noise assumes this double pass.
        """
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        """`step` output must have the same shape as the input sample."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
            scheduler.ets = dummy_past_residuals[:]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            # repeat with a warm `ets` history — shapes must be unchanged
            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        """Round-trip check across different training-schedule lengths."""
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps, time_step=None)

    def test_inference_steps(self):
        """Round-trip check across different inference-step counts."""
        for num_inference_steps in [10, 50, 100]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=None)

    def test_full_loop_no_noise(self):
        """Pin the mean magnitude of the deterministic full-loop output."""
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 2540529) < 10