Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_ddpm.py
1448 views
1
import torch
2
3
from diffusers import DDPMScheduler
4
5
from .test_schedulers import SchedulerCommonTest
6
7
8
class DDPMSchedulerTest(SchedulerCommonTest):
    """Tests for ``DDPMScheduler`` built on the shared ``SchedulerCommonTest`` harness."""

    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config used by these tests.

        Any keyword arguments override (or add to) the default entries.
        """
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for threshold in [0.5, 1.0, 2.0]:
            for prediction_type in ["epsilon", "sample", "v_prediction"]:
                self.check_over_configs(
                    thresholding=True,
                    prediction_type=prediction_type,
                    sample_max_value=threshold,
                )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        # Reference values for the default linear beta schedule at first, middle and last step.
        assert torch.sum(torch.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def _run_full_loop(self, **config):
        """Run the scheduler over every training timestep and return the final sample.

        Keyword arguments override entries of ``get_scheduler_config``. The loop starts
        from ``self.dummy_sample_deter`` and walks timesteps from last to first. At each
        step it feeds the ``prev_sample`` returned by ``scheduler.step`` straight back in
        as the next input. The global torch seed is set to 0 after the model is created,
        so results are reproducible.
        """
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        for t in reversed(range(num_trained_timesteps)):
            # 1. model output at step t; its meaning depends on the configured prediction_type
            model_output = model(sample, t)

            # 2. compute the previous sample x_{t-1} and use it as the next input
            sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample

        return sample

    def test_full_loop_no_noise(self):
        sample = self._run_full_loop()

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.9606) < 1e-2
        assert abs(result_mean.item() - 0.3372) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self._run_full_loop(prediction_type="v_prediction")

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 202.0296) < 1e-2
        assert abs(result_mean.item() - 0.2631) < 1e-3
132
133