# Path: blob/main/tests/schedulers/test_scheduler_pndm.py
# 1448 views
import tempfile

import torch

from diffusers import PNDMScheduler

from .test_schedulers import SchedulerCommonTest


class PNDMSchedulerTest(SchedulerCommonTest):
    """Tests for ``PNDMScheduler``, exercising both its PRK and PLMS stepping modes."""

    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Return a default linear-beta scheduler config, with ``kwargs`` overriding any key."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    @staticmethod
    def _dummy_past_residuals(residual):
        """Return four fake past model outputs used to pre-fill ``scheduler.ets``."""
        return [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

    def _assert_save_load_equivalence(
        self, scheduler_class, scheduler_config, num_inference_steps, time_step, sample, step_kwargs
    ):
        """Check that a save_config/from_pretrained round trip preserves PRK and PLMS outputs.

        Args:
            scheduler_class: Scheduler class to instantiate.
            scheduler_config: Keyword arguments for the scheduler constructor.
            num_inference_steps: Value passed to ``set_timesteps`` on both schedulers.
            time_step: Timestep passed to ``step_prk`` / ``step_plms``.
            sample: Input sample; the residual is derived from it as ``0.1 * sample``.
            step_kwargs: Extra keyword arguments forwarded to both step methods.
        """
        residual = 0.1 * sample
        dummy_past_residuals = self._dummy_past_residuals(residual)

        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(num_inference_steps)
        # copy over dummy past residuals (must be after setting timesteps)
        scheduler.ets = dummy_past_residuals[:]

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            new_scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals (must be after setting timesteps)
            new_scheduler.ets = dummy_past_residuals[:]

        # PRK runs first, then PLMS, on the same scheduler instances (the step methods may carry state).
        output = scheduler.step_prk(residual, time_step, sample, **step_kwargs).prev_sample
        new_output = new_scheduler.step_prk(residual, time_step, sample, **step_kwargs).prev_sample
        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

        output = scheduler.step_plms(residual, time_step, sample, **step_kwargs).prev_sample
        new_output = new_scheduler.step_plms(residual, time_step, sample, **step_kwargs).prev_sample
        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_configs(self, time_step=0, **config):
        """Round-trip a scheduler built with ``config`` overrides and compare step outputs."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            self._assert_save_load_equivalence(
                scheduler_class, scheduler_config, num_inference_steps, time_step, sample, kwargs
            )

    def test_from_save_pretrained(self):
        """Intentionally a no-op override of the common test.

        Save/load round-tripping is exercised through the PRK/PLMS step methods by
        ``check_over_configs`` and ``check_over_forward`` instead.
        """
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Round-trip a default scheduler and compare step outputs under ``forward_kwargs``.

        ``forward_kwargs`` override ``forward_default_kwargs``. ``num_inference_steps`` is
        consumed by ``set_timesteps``; any other keys are forwarded to the step methods.
        """
        kwargs = dict(self.forward_default_kwargs)
        # Merge caller overrides; otherwise e.g. num_inference_steps would always stay at the default.
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            self._assert_save_load_equivalence(
                scheduler_class, scheduler_config, num_inference_steps, time_step, sample, kwargs
            )

    def full_loop(self, **config):
        """Run all PRK then all PLMS timesteps with the dummy model and return the final sample."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.prk_timesteps:
            residual = model(sample, t)
            sample = scheduler.step_prk(residual, t, sample).prev_sample

        for t in scheduler.plms_timesteps:
            residual = model(sample, t)
            sample = scheduler.step_plms(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        """PRK and PLMS steps must return tensors shaped like the input sample."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            scheduler.ets = self._dummy_past_residuals(residual)[:]

            output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(10)
        assert torch.equal(
            scheduler.timesteps,
            torch.LongTensor(
                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
            ),
        )

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            scheduler.set_timesteps(num_inference_steps)

            # before power of 3 fix, would error on first step, so we only need to do two
            for t in scheduler.prk_timesteps[:2]:
                sample = scheduler.step_prk(residual, t, sample).prev_sample

    def test_inference_plms_no_past_residuals(self):
        """``step_plms`` must raise ``ValueError`` when it has no past residuals to use."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()

        # Unconfigured scheduler (set_timesteps never called).
        scheduler = scheduler_class(**scheduler_config)
        with self.assertRaises(ValueError):
            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)

        # Configured scheduler whose residual history is empty (no PRK steps run yet).
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(10)
        with self.assertRaises(ValueError):
            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 198.1318) < 1e-2
        assert abs(result_mean.item() - 0.2580) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 67.3986) < 1e-2
        assert abs(result_mean.item() - 0.0878) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 230.0399) < 1e-2
        assert abs(result_mean.item() - 0.2995) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 186.9482) < 1e-2
        assert abs(result_mean.item() - 0.2434) < 1e-3