Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_deis.py
1448 views
1
import tempfile
2
3
import torch
4
5
from diffusers import (
6
DEISMultistepScheduler,
7
DPMSolverMultistepScheduler,
8
DPMSolverSinglestepScheduler,
9
UniPCMultistepScheduler,
10
)
11
12
from .test_schedulers import SchedulerCommonTest
13
14
15
class DEISMultistepSchedulerTest(SchedulerCommonTest):
    """Tests for ``DEISMultistepScheduler`` (config round-trips, step shapes,
    full denoising loops, fp16 support, and conversion to/from the other
    multistep scheduler types)."""

    # This suite exercises only the DEIS scheduler.
    scheduler_classes = (DEISMultistepScheduler,)
    # Default kwargs the common tests forward into scheduler.step().
    forward_default_kwargs = (("num_inference_steps", 25),)
def get_scheduler_config(self, **kwargs):
20
config = {
21
"num_train_timesteps": 1000,
22
"beta_start": 0.0001,
23
"beta_end": 0.02,
24
"beta_schedule": "linear",
25
"solver_order": 2,
26
}
27
28
config.update(**kwargs)
29
return config
30
31
def check_over_configs(self, time_step=0, **config):
32
kwargs = dict(self.forward_default_kwargs)
33
num_inference_steps = kwargs.pop("num_inference_steps", None)
34
sample = self.dummy_sample
35
residual = 0.1 * sample
36
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
37
38
for scheduler_class in self.scheduler_classes:
39
scheduler_config = self.get_scheduler_config(**config)
40
scheduler = scheduler_class(**scheduler_config)
41
scheduler.set_timesteps(num_inference_steps)
42
# copy over dummy past residuals
43
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
44
45
with tempfile.TemporaryDirectory() as tmpdirname:
46
scheduler.save_config(tmpdirname)
47
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
48
new_scheduler.set_timesteps(num_inference_steps)
49
# copy over dummy past residuals
50
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
51
52
output, new_output = sample, sample
53
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
54
output = scheduler.step(residual, t, output, **kwargs).prev_sample
55
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
56
57
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
58
59
    def test_from_save_pretrained(self):
        # Deliberate no-op override: the save/from_pretrained round trip is
        # already exercised here by check_over_configs/check_over_forward,
        # which also seed the multistep solver's past-residual history.
        pass
61
62
def check_over_forward(self, time_step=0, **forward_kwargs):
63
kwargs = dict(self.forward_default_kwargs)
64
num_inference_steps = kwargs.pop("num_inference_steps", None)
65
sample = self.dummy_sample
66
residual = 0.1 * sample
67
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
68
69
for scheduler_class in self.scheduler_classes:
70
scheduler_config = self.get_scheduler_config()
71
scheduler = scheduler_class(**scheduler_config)
72
scheduler.set_timesteps(num_inference_steps)
73
74
# copy over dummy past residuals (must be after setting timesteps)
75
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
76
77
with tempfile.TemporaryDirectory() as tmpdirname:
78
scheduler.save_config(tmpdirname)
79
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
80
# copy over dummy past residuals
81
new_scheduler.set_timesteps(num_inference_steps)
82
83
# copy over dummy past residual (must be after setting timesteps)
84
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
85
86
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
87
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
88
89
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
90
91
def full_loop(self, scheduler=None, **config):
92
if scheduler is None:
93
scheduler_class = self.scheduler_classes[0]
94
scheduler_config = self.get_scheduler_config(**config)
95
scheduler = scheduler_class(**scheduler_config)
96
97
scheduler_class = self.scheduler_classes[0]
98
scheduler_config = self.get_scheduler_config(**config)
99
scheduler = scheduler_class(**scheduler_config)
100
101
num_inference_steps = 10
102
model = self.dummy_model()
103
sample = self.dummy_sample_deter
104
scheduler.set_timesteps(num_inference_steps)
105
106
for i, t in enumerate(scheduler.timesteps):
107
residual = model(sample, t)
108
sample = scheduler.step(residual, t, sample).prev_sample
109
110
return sample
111
112
def test_step_shape(self):
113
kwargs = dict(self.forward_default_kwargs)
114
115
num_inference_steps = kwargs.pop("num_inference_steps", None)
116
117
for scheduler_class in self.scheduler_classes:
118
scheduler_config = self.get_scheduler_config()
119
scheduler = scheduler_class(**scheduler_config)
120
121
sample = self.dummy_sample
122
residual = 0.1 * sample
123
124
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
125
scheduler.set_timesteps(num_inference_steps)
126
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
127
kwargs["num_inference_steps"] = num_inference_steps
128
129
# copy over dummy past residuals (must be done after set_timesteps)
130
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
131
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
132
133
time_step_0 = scheduler.timesteps[5]
134
time_step_1 = scheduler.timesteps[6]
135
136
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
137
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
138
139
self.assertEqual(output_0.shape, sample.shape)
140
self.assertEqual(output_0.shape, output_1.shape)
141
142
def test_switch(self):
143
# make sure that iterating over schedulers with same config names gives same results
144
# for defaults
145
scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
146
sample = self.full_loop(scheduler=scheduler)
147
result_mean = torch.mean(torch.abs(sample))
148
149
assert abs(result_mean.item() - 0.23916) < 1e-3
150
151
scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
152
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
153
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
154
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
155
156
sample = self.full_loop(scheduler=scheduler)
157
result_mean = torch.mean(torch.abs(sample))
158
159
assert abs(result_mean.item() - 0.23916) < 1e-3
160
161
def test_timesteps(self):
162
for timesteps in [25, 50, 100, 999, 1000]:
163
self.check_over_configs(num_train_timesteps=timesteps)
164
165
def test_thresholding(self):
166
self.check_over_configs(thresholding=False)
167
for order in [1, 2, 3]:
168
for solver_type in ["logrho"]:
169
for threshold in [0.5, 1.0, 2.0]:
170
for prediction_type in ["epsilon", "sample"]:
171
self.check_over_configs(
172
thresholding=True,
173
prediction_type=prediction_type,
174
sample_max_value=threshold,
175
algorithm_type="deis",
176
solver_order=order,
177
solver_type=solver_type,
178
)
179
180
def test_prediction_type(self):
181
for prediction_type in ["epsilon", "v_prediction"]:
182
self.check_over_configs(prediction_type=prediction_type)
183
184
def test_solver_order_and_type(self):
185
for algorithm_type in ["deis"]:
186
for solver_type in ["logrho"]:
187
for order in [1, 2, 3]:
188
for prediction_type in ["epsilon", "sample"]:
189
self.check_over_configs(
190
solver_order=order,
191
solver_type=solver_type,
192
prediction_type=prediction_type,
193
algorithm_type=algorithm_type,
194
)
195
sample = self.full_loop(
196
solver_order=order,
197
solver_type=solver_type,
198
prediction_type=prediction_type,
199
algorithm_type=algorithm_type,
200
)
201
assert not torch.isnan(sample).any(), "Samples have nan numbers"
202
203
def test_lower_order_final(self):
204
self.check_over_configs(lower_order_final=True)
205
self.check_over_configs(lower_order_final=False)
206
207
def test_inference_steps(self):
208
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
209
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
210
211
def test_full_loop_no_noise(self):
212
sample = self.full_loop()
213
result_mean = torch.mean(torch.abs(sample))
214
215
assert abs(result_mean.item() - 0.23916) < 1e-3
216
217
def test_full_loop_with_v_prediction(self):
218
sample = self.full_loop(prediction_type="v_prediction")
219
result_mean = torch.mean(torch.abs(sample))
220
221
assert abs(result_mean.item() - 0.091) < 1e-3
222
223
def test_fp16_support(self):
224
scheduler_class = self.scheduler_classes[0]
225
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
226
scheduler = scheduler_class(**scheduler_config)
227
228
num_inference_steps = 10
229
model = self.dummy_model()
230
sample = self.dummy_sample_deter.half()
231
scheduler.set_timesteps(num_inference_steps)
232
233
for i, t in enumerate(scheduler.timesteps):
234
residual = model(sample, t)
235
sample = scheduler.step(residual, t, sample).prev_sample
236
237
assert sample.dtype == torch.float16
238
239