GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_dpm_single.py

import tempfile

import torch

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

from .test_schedulers import SchedulerCommonTest


class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DPMSolverSinglestepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "prediction_type": "epsilon",
            "thresholding": False,
            "sample_max_value": 1.0,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
        }

        config.update(**kwargs)
        return config

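    # A note on the defaults above: "dpmsolver++" is the data-prediction
    # variant of DPM-Solver, "solver_order": 2 with "midpoint" gives a
    # second-order update, and thresholding (off by default) dynamically
    # clips the predicted sample, which is mainly useful for pixel-space
    # models rather than latent-space ones.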
    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

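    # The common save/load test is skipped here; the save/load round trip is
    # covered by check_over_configs above and check_over_forward below, which
    # also restore the solver's cached model outputs after loading.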
    def test_from_save_pretrained(self):
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)  # let callers override defaults such as num_inference_steps
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

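    # full_loop runs a deterministic 10-step denoising loop with the dummy
    # model and returns the final sample; the numeric tests below check the
    # mean of that sample against reference values.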
    def full_loop(self, scheduler=None, **config):
        if scheduler is None:
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

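    # DEIS, DPMSolverMultistep, UniPC, and DPMSolverSinglestep register each
    # other as compatible classes, so `from_config` can convert a config
    # between them; the chain below ends back at DPMSolverSinglestepScheduler,
    # which should reproduce the original result.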
    def test_switch(self):
        # make sure that switching between the compatible scheduler classes
        # gives the same results for the default config
        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)

        sample = self.full_loop(scheduler=scheduler)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

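    # Dynamic thresholding is applied in the data-prediction ("dpmsolver++")
    # path, so the sweep below fixes algorithm_type accordingly.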
    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["midpoint", "heun"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            algorithm_type="dpmsolver++",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
            for solver_type in ["midpoint", "heun"]:
                for order in [1, 2, 3]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples contain NaN values"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.1453) < 1e-3

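    # The fp16 test below runs the full loop in half precision (with
    # thresholding enabled) and checks that the scheduler preserves the
    # input dtype end to end.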
    def test_fp16_support(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16
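
# To run just this module (assuming a diffusers dev checkout with the test
# dependencies installed):
#
#   pytest tests/schedulers/test_scheduler_dpm_single.py -q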