Path: blob/main/tests/schedulers/test_scheduler_unclip.py
1448 views
import torch

from diffusers import UnCLIPScheduler

from .test_schedulers import SchedulerCommonTest


# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
class UnCLIPSchedulerTest(SchedulerCommonTest):
    """Tests for ``UnCLIPScheduler``.

    The ``check_over_configs`` / ``check_over_forward`` helpers and the
    ``dummy_model`` / ``dummy_sample_deter`` fixtures used below are not defined
    in this file; presumably they are inherited from ``SchedulerCommonTest``.
    """

    scheduler_classes = (UnCLIPScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config dict, with ``kwargs`` overriding or adding keys."""
        config = {
            "num_train_timesteps": 1000,
            "variance_type": "fixed_small_log",
            "clip_sample": True,
            "clip_sample_range": 1.0,
            "prediction_type": "epsilon",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Run ``check_over_configs`` for several ``num_train_timesteps`` values."""
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_variance_type(self):
        """Run ``check_over_configs`` for each ``variance_type`` exercised by this suite."""
        for variance in ["fixed_small_log", "learned_range"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        """Run ``check_over_configs`` with ``clip_sample`` enabled and disabled."""
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_clip_sample_range(self):
        """Run ``check_over_configs`` for several ``clip_sample_range`` values."""
        for clip_sample_range in [1, 5, 10, 20]:
            self.check_over_configs(clip_sample_range=clip_sample_range)

    def test_prediction_type(self):
        """Run ``check_over_configs`` for the ``epsilon`` and ``sample`` prediction types."""
        for prediction_type in ["epsilon", "sample"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        """Run ``check_over_forward`` over (time_step, prev_timestep) pairs.

        Pairs where ``prev_timestep`` is not strictly smaller than ``time_step``
        are skipped; ``prev_timestep=None`` is always exercised.
        """
        for time_step in [0, 500, 999]:
            for prev_timestep in [None, 5, 100, 250, 500, 750]:
                if prev_timestep is not None and prev_timestep >= time_step:
                    continue

                self.check_over_forward(time_step=time_step, prev_timestep=prev_timestep)

    def test_variance_fixed_small_log(self):
        """Compare ``_get_variance`` with reference values for ``fixed_small_log``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="fixed_small_log")
        scheduler = scheduler_class(**scheduler_config)

        assert torch.sum(torch.abs(scheduler._get_variance(0) - 1.0000e-10)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.0549625)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.9994987)) < 1e-5

    def test_variance_learned_range(self):
        """Compare ``_get_variance`` with reference values for ``learned_range``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="learned_range")
        scheduler = scheduler_class(**scheduler_config)

        predicted_variance = 0.5

        # abs() makes each check two-sided. Without it, any value at or below the
        # reference (however far off) satisfies ``value - reference < 1e-5``.
        # NOTE(review): these references were previously only checked from one
        # side; confirm they hold under the two-sided comparison.
        assert abs(scheduler._get_variance(1, predicted_variance=predicted_variance) - -10.1712790) < 1e-5
        assert abs(scheduler._get_variance(487, predicted_variance=predicted_variance) - -5.7998052) < 1e-5
        assert abs(scheduler._get_variance(999, predicted_variance=predicted_variance) - -0.0010011) < 1e-5

    def test_full_loop(self):
        """Step through every scheduler timestep and compare the final sample with reference statistics."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        # Seeded generator so the values compared below are reproducible.
        generator = torch.manual_seed(0)

        for t in timesteps:
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

            sample = pred_prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 252.2682495) < 1e-2
        assert abs(result_mean.item() - 0.3284743) < 1e-3

    def test_full_loop_skip_timesteps(self):
        """Same as the full loop, but after ``set_timesteps(25)`` and with an explicit ``prev_timestep``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(25)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        # Seeded generator so the values compared below are reproducible.
        generator = torch.manual_seed(0)

        for i, t in enumerate(timesteps):
            # 1. predict noise residual
            residual = model(sample, t)

            # The last iteration has no following timestep, so pass None.
            if i + 1 == timesteps.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = timesteps[i + 1]

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(
                residual, t, sample, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

            sample = pred_prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.2044983) < 1e-2
        assert abs(result_mean.item() - 0.3362038) < 1e-3

    # No-op override. NOTE(review): presumably disables an inherited test that
    # does not apply to UnCLIPScheduler; confirm against SchedulerCommonTest.
    def test_trained_betas(self):
        pass

    # No-op override. NOTE(review): presumably disables an inherited test that
    # does not apply to UnCLIPScheduler; confirm against SchedulerCommonTest.
    def test_add_noise_device(self):
        pass