Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_pndm.py
1448 views
1
import tempfile
2
3
import torch
4
5
from diffusers import PNDMScheduler
6
7
from .test_schedulers import SchedulerCommonTest
8
9
10
class PNDMSchedulerTest(SchedulerCommonTest):
    """Tests for ``PNDMScheduler``: config round-trips, step shapes, and
    pinned full-loop outputs for its PRK (``step_prk``) and PLMS
    (``step_plms``) stepping modes."""

    # Scheduler class(es) under test; the shared harness iterates this tuple.
    scheduler_classes = (PNDMScheduler,)
    # Default kwargs for the forward/step helpers; "num_inference_steps" is
    # popped off and routed to set_timesteps() by the check helpers below.
    forward_default_kwargs = (("num_inference_steps", 50),)
14
def get_scheduler_config(self, **kwargs):
15
config = {
16
"num_train_timesteps": 1000,
17
"beta_start": 0.0001,
18
"beta_end": 0.02,
19
"beta_schedule": "linear",
20
}
21
22
config.update(**kwargs)
23
return config
24
25
def check_over_configs(self, time_step=0, **config):
26
kwargs = dict(self.forward_default_kwargs)
27
num_inference_steps = kwargs.pop("num_inference_steps", None)
28
sample = self.dummy_sample
29
residual = 0.1 * sample
30
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
31
32
for scheduler_class in self.scheduler_classes:
33
scheduler_config = self.get_scheduler_config(**config)
34
scheduler = scheduler_class(**scheduler_config)
35
scheduler.set_timesteps(num_inference_steps)
36
# copy over dummy past residuals
37
scheduler.ets = dummy_past_residuals[:]
38
39
with tempfile.TemporaryDirectory() as tmpdirname:
40
scheduler.save_config(tmpdirname)
41
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
42
new_scheduler.set_timesteps(num_inference_steps)
43
# copy over dummy past residuals
44
new_scheduler.ets = dummy_past_residuals[:]
45
46
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
47
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
48
49
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
50
51
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
52
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
53
54
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
55
56
    def test_from_save_pretrained(self):
        # Deliberate no-op override of the common test.
        # NOTE(review): presumably the shared version doesn't account for
        # PNDM's stateful residual history (ets) — confirm against
        # SchedulerCommonTest. The save/load round-trip is instead exercised
        # by check_over_configs / check_over_forward in this class.
        pass
58
59
def check_over_forward(self, time_step=0, **forward_kwargs):
60
kwargs = dict(self.forward_default_kwargs)
61
num_inference_steps = kwargs.pop("num_inference_steps", None)
62
sample = self.dummy_sample
63
residual = 0.1 * sample
64
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
65
66
for scheduler_class in self.scheduler_classes:
67
scheduler_config = self.get_scheduler_config()
68
scheduler = scheduler_class(**scheduler_config)
69
scheduler.set_timesteps(num_inference_steps)
70
71
# copy over dummy past residuals (must be after setting timesteps)
72
scheduler.ets = dummy_past_residuals[:]
73
74
with tempfile.TemporaryDirectory() as tmpdirname:
75
scheduler.save_config(tmpdirname)
76
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
77
# copy over dummy past residuals
78
new_scheduler.set_timesteps(num_inference_steps)
79
80
# copy over dummy past residual (must be after setting timesteps)
81
new_scheduler.ets = dummy_past_residuals[:]
82
83
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
84
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
85
86
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
87
88
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
89
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
90
91
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
92
93
def full_loop(self, **config):
94
scheduler_class = self.scheduler_classes[0]
95
scheduler_config = self.get_scheduler_config(**config)
96
scheduler = scheduler_class(**scheduler_config)
97
98
num_inference_steps = 10
99
model = self.dummy_model()
100
sample = self.dummy_sample_deter
101
scheduler.set_timesteps(num_inference_steps)
102
103
for i, t in enumerate(scheduler.prk_timesteps):
104
residual = model(sample, t)
105
sample = scheduler.step_prk(residual, t, sample).prev_sample
106
107
for i, t in enumerate(scheduler.plms_timesteps):
108
residual = model(sample, t)
109
sample = scheduler.step_plms(residual, t, sample).prev_sample
110
111
return sample
112
113
def test_step_shape(self):
114
kwargs = dict(self.forward_default_kwargs)
115
116
num_inference_steps = kwargs.pop("num_inference_steps", None)
117
118
for scheduler_class in self.scheduler_classes:
119
scheduler_config = self.get_scheduler_config()
120
scheduler = scheduler_class(**scheduler_config)
121
122
sample = self.dummy_sample
123
residual = 0.1 * sample
124
125
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
126
scheduler.set_timesteps(num_inference_steps)
127
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
128
kwargs["num_inference_steps"] = num_inference_steps
129
130
# copy over dummy past residuals (must be done after set_timesteps)
131
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
132
scheduler.ets = dummy_past_residuals[:]
133
134
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
135
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
136
137
self.assertEqual(output_0.shape, sample.shape)
138
self.assertEqual(output_0.shape, output_1.shape)
139
140
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
141
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
142
143
self.assertEqual(output_0.shape, sample.shape)
144
self.assertEqual(output_0.shape, output_1.shape)
145
146
def test_timesteps(self):
147
for timesteps in [100, 1000]:
148
self.check_over_configs(num_train_timesteps=timesteps)
149
150
def test_steps_offset(self):
151
for steps_offset in [0, 1]:
152
self.check_over_configs(steps_offset=steps_offset)
153
154
scheduler_class = self.scheduler_classes[0]
155
scheduler_config = self.get_scheduler_config(steps_offset=1)
156
scheduler = scheduler_class(**scheduler_config)
157
scheduler.set_timesteps(10)
158
assert torch.equal(
159
scheduler.timesteps,
160
torch.LongTensor(
161
[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
162
),
163
)
164
165
def test_betas(self):
166
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
167
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
168
169
def test_schedules(self):
170
for schedule in ["linear", "squaredcos_cap_v2"]:
171
self.check_over_configs(beta_schedule=schedule)
172
173
def test_prediction_type(self):
174
for prediction_type in ["epsilon", "v_prediction"]:
175
self.check_over_configs(prediction_type=prediction_type)
176
177
def test_time_indices(self):
178
for t in [1, 5, 10]:
179
self.check_over_forward(time_step=t)
180
181
def test_inference_steps(self):
182
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
183
self.check_over_forward(num_inference_steps=num_inference_steps)
184
185
def test_pow_of_3_inference_steps(self):
186
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
187
num_inference_steps = 27
188
189
for scheduler_class in self.scheduler_classes:
190
sample = self.dummy_sample
191
residual = 0.1 * sample
192
193
scheduler_config = self.get_scheduler_config()
194
scheduler = scheduler_class(**scheduler_config)
195
196
scheduler.set_timesteps(num_inference_steps)
197
198
# before power of 3 fix, would error on first step, so we only need to do two
199
for i, t in enumerate(scheduler.prk_timesteps[:2]):
200
sample = scheduler.step_prk(residual, t, sample).prev_sample
201
202
def test_inference_plms_no_past_residuals(self):
203
with self.assertRaises(ValueError):
204
scheduler_class = self.scheduler_classes[0]
205
scheduler_config = self.get_scheduler_config()
206
scheduler = scheduler_class(**scheduler_config)
207
208
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
209
210
def test_full_loop_no_noise(self):
211
sample = self.full_loop()
212
result_sum = torch.sum(torch.abs(sample))
213
result_mean = torch.mean(torch.abs(sample))
214
215
assert abs(result_sum.item() - 198.1318) < 1e-2
216
assert abs(result_mean.item() - 0.2580) < 1e-3
217
218
def test_full_loop_with_v_prediction(self):
219
sample = self.full_loop(prediction_type="v_prediction")
220
result_sum = torch.sum(torch.abs(sample))
221
result_mean = torch.mean(torch.abs(sample))
222
223
assert abs(result_sum.item() - 67.3986) < 1e-2
224
assert abs(result_mean.item() - 0.0878) < 1e-3
225
226
def test_full_loop_with_set_alpha_to_one(self):
227
# We specify different beta, so that the first alpha is 0.99
228
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
229
result_sum = torch.sum(torch.abs(sample))
230
result_mean = torch.mean(torch.abs(sample))
231
232
assert abs(result_sum.item() - 230.0399) < 1e-2
233
assert abs(result_mean.item() - 0.2995) < 1e-3
234
235
def test_full_loop_with_no_set_alpha_to_one(self):
236
# We specify different beta, so that the first alpha is 0.99
237
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
238
result_sum = torch.sum(torch.abs(sample))
239
result_mean = torch.mean(torch.abs(sample))
240
241
assert abs(result_sum.item() - 186.9482) < 1e-2
242
assert abs(result_mean.item() - 0.2434) < 1e-3
243
244