GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_unipc.py
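"""Tests for the UniPCMultistepScheduler from diffusers."""
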
import tempfile

import torch

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

from .test_schedulers import SchedulerCommonTest


class UniPCMultistepSchedulerTest(SchedulerCommonTest):
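    """Tests for UniPCMultistepScheduler, built on the shared SchedulerCommonTest harness."""
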
    scheduler_classes = (UniPCMultistepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
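        """Return the default UniPC test config, with any keyword overrides applied on top."""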
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "solver_type": "bh1",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
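        """Build a scheduler with the given config overrides, round-trip it through
        save_config/from_pretrained, and assert both instances produce identical step outputs."""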
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
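        """Like check_over_configs, but round-trip with overridden forward kwargs
        (e.g. num_inference_steps) and compare one step() call before and after save/load."""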
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)  # apply the per-test overrides (e.g. num_inference_steps)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, scheduler=None, **config):
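        """Run a short denoising loop with the dummy model and return the final sample."""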
        # only build a default scheduler when the caller did not pass one in,
        # so that tests like test_switch can exercise a converted scheduler
        if scheduler is None:
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
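        """step() should return samples with the same shape as its input at consecutive timesteps."""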
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_switch(self):
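        """Converting the config through sibling solver classes and back to UniPC
        should reproduce the default full-loop result."""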
        # make sure that iterating over schedulers with the same config gives the same results
        # for the defaults
        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2521) < 1e-3

        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)

        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2521) < 1e-3

    def test_timesteps(self):
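        """Config round-trips should hold for a range of num_train_timesteps values."""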
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
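        """Config round-trips should hold with dynamic thresholding across solver
        orders, solver types, thresholds, and prediction types."""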
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["bh1", "bh2"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
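        """Config round-trips should hold for both epsilon and v_prediction parameterizations."""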
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
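        """Every solver order/type/prediction-type combination should round-trip and
        produce NaN-free full-loop samples."""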
        for solver_type in ["bh1", "bh2"]:
            for order in [1, 2, 3]:
                for prediction_type in ["epsilon", "sample"]:
                    self.check_over_configs(
                        solver_order=order,
                        solver_type=solver_type,
                        prediction_type=prediction_type,
                    )
                    sample = self.full_loop(
                        solver_order=order,
                        solver_type=solver_type,
                        prediction_type=prediction_type,
                    )
                    assert not torch.isnan(sample).any(), "Samples contain NaN values"

    def test_lower_order_final(self):
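        """Config round-trips should hold with lower_order_final both on and off."""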
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
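        """Save/load round-trips should hold across a wide range of inference step counts."""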
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
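        """The default full loop should match the expected mean absolute sample value."""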
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2521) < 1e-3

    def test_full_loop_with_v_prediction(self):
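        """The full loop with v_prediction should match its expected mean absolute sample value."""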
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.1096) < 1e-3

    def test_fp16_support(self):
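        """Running the loop on half-precision inputs should keep the sample in float16."""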
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16