# Path: blob/main/tests/schedulers/test_scheduler_ddpm.py
# 1448 views
import torch

from diffusers import DDPMScheduler

from .test_schedulers import SchedulerCommonTest


class DDPMSchedulerTest(SchedulerCommonTest):
    """Test cases for ``DDPMScheduler``.

    The config sweeps delegate to ``check_over_configs`` /
    ``check_over_forward``, and the full-loop tests use ``dummy_model`` and
    ``dummy_sample_deter``; all four are inherited from
    ``SchedulerCommonTest``, which is defined outside this file.
    """

    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the baseline scheduler config with ``kwargs`` applied on top.

        Args:
            **kwargs: Config entries that override (or extend) the defaults.

        Returns:
            dict: A fresh config dict; a new one is built on every call.
        """
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def _full_denoising_loop(self, **config_overrides):
        """Step the scheduler through every trained timestep, last to first.

        Shared by the full-loop tests, which differ only in the config they
        use and in the reference numbers they compare against.

        Args:
            **config_overrides: Forwarded to ``get_scheduler_config``.

        Returns:
            The sample obtained after the final ``scheduler.step`` call.
        """
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config_overrides)
        scheduler = scheduler_class(**scheduler_config)

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        # Seed after the model is built (same order as the original tests) so
        # the generator handed to ``step`` starts from a fixed state.
        generator = torch.manual_seed(0)

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous sample x_t-1; the scheduler output is fed
            #    back unchanged, no extra noise is added by the test itself.
            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

        return sample

    def test_timesteps(self):
        """Sweep ``num_train_timesteps`` through ``check_over_configs``."""
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        """Sweep paired ``beta_start`` / ``beta_end`` values."""
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        """Sweep the ``beta_schedule`` option."""
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        """Sweep the ``variance_type`` option."""
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        """Run the config check with ``clip_sample`` on and off."""
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_thresholding(self):
        """Check thresholding off, then on for each threshold / prediction type."""
        self.check_over_configs(thresholding=False)
        for threshold in [0.5, 1.0, 2.0]:
            for prediction_type in ["epsilon", "sample", "v_prediction"]:
                self.check_over_configs(
                    thresholding=True,
                    prediction_type=prediction_type,
                    sample_max_value=threshold,
                )

    def test_prediction_type(self):
        """Sweep the ``prediction_type`` option."""
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        """Run ``check_over_forward`` at the first, a middle and the last timestep."""
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        """Compare ``_get_variance`` with reference values at three timesteps."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert torch.sum(torch.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        """Full loop with the default config must match the reference sums."""
        sample = self._full_denoising_loop()

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.9606) < 1e-2
        assert abs(result_mean.item() - 0.3372) < 1e-3

    def test_full_loop_with_v_prediction(self):
        """Full loop with ``prediction_type="v_prediction"`` must match the reference sums."""
        sample = self._full_denoising_loop(prediction_type="v_prediction")

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 202.0296) < 1e-2
        assert abs(result_mean.item() - 0.2631) < 1e-3