Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_vq_diffusion.py
1450 views
1
import torch
2
import torch.nn.functional as F
3
4
from diffusers import VQDiffusionScheduler
5
6
from .test_schedulers import SchedulerCommonTest
7
8
9
class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Test case that runs the shared scheduler checks on ``VQDiffusionScheduler``.

    The helpers below produce integer samples and a log-probability model
    instead of continuous tensors. The ``check_over_*`` helpers used by the
    tests are not defined in this class (presumably inherited from
    ``SchedulerCommonTest`` -- confirm there).
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config with ``kwargs`` applied on top.

        Args:
            **kwargs: Config entries that override or extend the defaults.

        Returns:
            dict: ``num_vec_classes`` and ``num_train_timesteps`` defaults,
            updated with ``kwargs``.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }

        config.update(**kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Return a random integer sample of shape ``(4, 64)``.

        Args:
            num_vec_classes: Exclusive upper bound for the sampled values.

        Returns:
            torch.Tensor: Integers drawn from ``[0, num_vec_classes)`` with
            shape ``(batch_size, height * width)``.
        """
        batch_size = 4
        height = 8
        width = 8

        sample = torch.randint(0, num_vec_classes, (batch_size, height * width))

        return sample

    @property
    def dummy_sample_deter(self):
        """Always raise: no deterministic sample is provided by this test case.

        Raises:
            AssertionError: On every access.
        """
        # Raise explicitly rather than `assert False`: assert statements are
        # removed under `python -O`, which would make this silently return None.
        raise AssertionError("dummy_sample_deter is not available for VQDiffusionSchedulerTest")

    def dummy_model(self, num_vec_classes):
        """Build a stand-in model that returns random log-probabilities.

        Args:
            num_vec_classes: Number of classes; the model output has
                ``num_vec_classes - 1`` entries along dim 1.

        Returns:
            Callable ``model(sample, t, *args)`` that ignores ``t`` and
            ``args`` and returns a float tensor of shape
            ``(batch_size, num_vec_classes - 1, num_latent_pixels)``.
        """

        def model(sample, t, *args):
            # `sample` must be 2-D; its shape fixes the output's outer dims.
            batch_size, num_latent_pixels = sample.shape
            # One class fewer than `num_vec_classes` -- presumably the last
            # class is a mask token the model never predicts; TODO confirm.
            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
            # Normalise over the class dimension in double precision, then
            # cast back to float32.
            return_value = F.log_softmax(logits.double(), dim=1).float()
            return return_value

        return model

    def test_timesteps(self):
        """Run ``check_over_configs`` for several ``num_train_timesteps`` values."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run ``check_over_configs`` for several ``num_vec_classes`` values."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run ``check_over_forward`` at the first, a middle and the last default step."""
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    def test_add_noise_device(self):
        """Intentionally a no-op.

        NOTE(review): presumably overrides an inherited test that does not
        apply to this scheduler -- confirm against ``SchedulerCommonTest``.
        """
        pass
57
58