GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_scheduler_flax.py

# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple

from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from jax import random

    jax_device = jax.default_backend()
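
# Note: the expected sums/means in the full-loop tests below differ slightly
# between backends, so those assertions branch on `jax_device` ("tpu" vs. CPU/GPU).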


@require_flax
class FlaxSchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        key1, key2 = random.split(random.PRNGKey(0))
        sample = random.uniform(key1, (batch_size, num_channels, height, width))

        return sample, key2

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = jnp.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        return jnp.transpose(sample, (3, 0, 1, 2))

    def get_scheduler_config(self):
        raise NotImplementedError

    def dummy_model(self):
        # Deterministic stand-in for a denoising model: scales the sample by t / (t + 1).
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model
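
    # Flax schedulers are stateless: configuration lives on the scheduler object,
    # while mutable quantities (timesteps, counters, past residuals) live in a
    # separate state object that every call below receives and returns.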

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_deprecated_kwargs(self):
        for scheduler_class in self.scheduler_classes:
            has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
            has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0

            if has_kwarg_in_model_class and not has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
                    " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                    " [<deprecated_argument>]`"
                )

            if not has_kwarg_in_model_class and has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
                    f" argument to {scheduler_class}.__init__ if there are deprecated arguments or remove the"
                    " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
                )
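

# Illustrative sketch, not one of the tests: the config save/load round-trip that
# check_over_configs and check_over_forward exercise above. All APIs shown
# (create_state, save_config, from_pretrained) are used elsewhere in this file;
# the helper itself and its name are ours.
def _example_scheduler_roundtrip():
    scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()
    with tempfile.TemporaryDirectory() as tmpdirname:
        scheduler.save_config(tmpdirname)
        # from_pretrained returns the scheduler together with a freshly created state
        new_scheduler, new_state = FlaxDDPMScheduler.from_pretrained(tmpdirname)
    return state, new_scheduler, new_state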


@require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 255.0714) < 1e-2
            assert abs(result_mean - 0.332124) < 1e-3
        else:
            assert abs(result_sum - 255.1113) < 1e-2
            assert abs(result_mean - 0.332176) < 1e-3
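

# Reference sketch, not used by the tests: the "fixed_small" variance that
# FlaxDDPMSchedulerTest.test_variance checks above, assuming the standard DDPM
# definition var_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) with the
# linear beta schedule from get_scheduler_config(). The helper and its name are ours.
def _reference_fixed_small_variance(t, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
    betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    # alpha_bar_{-1} is defined as 1, so the variance at t = 0 is exactly 0
    alpha_bar_prev = jnp.where(t > 0, alphas_cumprod[t - 1], 1.0)
    return betas[t] * (1.0 - alpha_bar_prev) / (1.0 - alphas_cumprod[t])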


@require_flax
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDIMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        key1, key2 = random.split(random.PRNGKey(0))

        num_inference_steps = 10

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        state = scheduler.set_timesteps(state, num_inference_steps)

        for t in state.timesteps:
            residual = model(sample, t)
            output = scheduler.step(state, residual, t, sample)
            sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

        return sample
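
    # Note: unlike the DDPM tests above, these DDIM helpers call scheduler.step
    # without a PRNG key, since the DDIM update is deterministic; the unused key
    # splitting in full_loop appears to be leftover from the stochastic variant.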

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 5)
        # 1000 train steps / 5 inference steps = stride 200, shifted by steps_offset=1
        assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 172.0067) < 1e-2
        assert abs(result_mean - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 149.8409) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3
        else:
            assert abs(result_sum - 149.8295) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            # FIXME: both result_sum and result_mean are nan on TPU
            # assert jnp.isnan(result_sum)
            # assert jnp.isnan(result_mean)
            pass
        else:
            assert abs(result_sum - 149.0784) < 1e-2
            assert abs(result_mean - 0.1941) < 1e-3

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)
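

# Reference sketch, not used by the tests: the deterministic (eta = 0) DDIM update
# that FlaxDDIMScheduler.step performs for prediction_type="epsilon", assuming the
# standard formulation x_{t-1} = sqrt(abar_prev) * x0_pred + sqrt(1 - abar_prev) * eps.
# The helper, its name, and its argument names are ours.
def _reference_ddim_step(sample, eps, alpha_bar_t, alpha_bar_prev):
    # reconstruct the predicted clean sample from the noise prediction
    x0_pred = (sample - jnp.sqrt(1.0 - alpha_bar_t) * eps) / jnp.sqrt(alpha_bar_t)
    # re-noise it to the previous timestep's noise level
    return jnp.sqrt(alpha_bar_prev) * x0_pred + jnp.sqrt(1.0 - alpha_bar_prev) * eps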


@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config
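
    # Unlike the DDIM/DDPM tests above, FlaxPNDMScheduler.set_timesteps also takes
    # the sample shape, and the most recent model outputs are carried in state.ets,
    # which the tests below pre-populate with dummy past residuals.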

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
                # copy over dummy past residuals
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        pass

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # copy over dummy past residuals (must be after setting timesteps)
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)

                # copy over dummy past residuals (must be after setting timesteps);
                # state.replace returns a new state, so the result must be reassigned
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

        # warm-up phase: Runge-Kutta (PRK) steps, followed by the multistep (PLMS) phase
        for i, t in enumerate(state.prk_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_prk(state, residual, t, sample)

        for i, t in enumerate(state.plms_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_plms(state, residual, t, sample)

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
            state = state.replace(ets=dummy_past_residuals[:])

            output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 10, shape=())
        assert jnp.equal(
            state.timesteps,
            jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
        ).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # an earlier version of set_timesteps() raised an indexing error when
        # num_inference_steps was a power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # before the power-of-3 fix this errored on the first step, so two steps suffice
            for i, t in enumerate(state.prk_timesteps[:2]):
                sample, state = scheduler.step_prk(state, residual, t, sample)

    def test_inference_plms_no_past_residuals(self):
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            scheduler.step_plms(state, sample, 1, sample).prev_sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 198.1275) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3
        else:
            assert abs(result_sum - 198.1318) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9466) < 1e-2
            assert abs(result_mean - 0.24342) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9482) < 1e-2
            assert abs(result_mean - 0.2434) < 1e-3