# Path: blob/main/tests/schedulers/test_scheduler_vq_diffusion.py
# (page metadata: 1450 views)
import torch
import torch.nn.functional as F

from diffusers import VQDiffusionScheduler

from .test_schedulers import SchedulerCommonTest


class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Scheduler tests for ``VQDiffusionScheduler``.

    Supplies the scheduler config, dummy sample and dummy model consumed by
    the shared ``SchedulerCommonTest`` helpers (``check_over_configs`` /
    ``check_over_forward``), which are defined outside this module.
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config with ``kwargs`` applied on top.

        Defaults are ``num_vec_classes=4097`` and ``num_train_timesteps=100``;
        any keyword argument overrides (or adds) the matching key.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }
        config.update(kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Return a random integer sample tensor of shape ``(4, 8 * 8)``.

        Every entry is drawn uniformly from ``[0, num_vec_classes)``. The second
        axis is ``height * width`` positions flattened into one dimension.
        """
        batch_size = 4
        height = 8
        width = 8
        return torch.randint(0, num_vec_classes, (batch_size, height * width))

    @property
    def dummy_sample_deter(self):
        """Not provided for this scheduler; accessing it is always an error.

        Raises:
            NotImplementedError: always.
        """
        # Raise explicitly rather than ``assert False``: asserts are stripped
        # under ``python -O``, which would make this property silently return
        # ``None`` and push the failure somewhere confusing.
        raise NotImplementedError(
            "dummy_sample_deter is not used by VQDiffusionSchedulerTest"
        )

    def dummy_model(self, num_vec_classes):
        """Return a stand-in model that emits random log-probabilities.

        The returned callable takes ``(sample, t, *args)``. ``sample`` must be a
        2-D ``(batch_size, num_latent_pixels)`` tensor; ``t`` and ``*args`` are
        ignored. It returns a float32 tensor of shape
        ``(batch_size, num_vec_classes - 1, num_latent_pixels)`` that is
        log-softmax normalised over dim 1.
        """

        def model(sample, t, *args):
            batch_size, num_latent_pixels = sample.shape
            # One fewer class than num_vec_classes. Presumably the model never
            # predicts the scheduler's mask class -- TODO confirm.
            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
            # Normalise in float64 (presumably for numerical accuracy), then
            # cast back to float32.
            return F.log_softmax(logits.double(), dim=1).float()

        return model

    def test_timesteps(self):
        """Run the shared config checks over several ``num_train_timesteps``."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run the shared config checks over several ``num_vec_classes``."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run the shared forward checks at several timestep indices.

        All indices (0, 50, 99) lie below the default ``num_train_timesteps``
        of 100 from ``get_scheduler_config``.
        """
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    def test_add_noise_device(self):
        """Intentionally a no-op.

        NOTE(review): presumably this overrides an inherited test of the same
        name that does not apply to this scheduler -- confirm against
        ``SchedulerCommonTest``.
        """