GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/schedulers/test_schedulers.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
from typing import Dict, List, Tuple

import numpy as np
import torch

import diffusers
from diffusers import (
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    IPNDMScheduler,
    LMSDiscreteScheduler,
    VQDiffusionScheduler,
    logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import CaptureLogger


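# Disable TF32 matmul kernels so the numerical-closeness asserts below behave
# the same on Ampere+ GPUs as on CPU.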
torch.backends.cuda.matmul.allow_tf32 = False


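# Minimal scheduler stubs: all three share the parameters `a`-`d` but differ in
# whether they register `e`, `f`, or both, so the tests below can exercise
# config round-trips between classes with mismatched signatures.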
class SchedulerObject(SchedulerMixin, ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass


class SchedulerObject2(SchedulerMixin, ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        f=[1, 3],
    ):
        pass


class SchedulerObject3(SchedulerMixin, ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
        f=[1, 3],
    ):
        pass


class SchedulerBaseTests(unittest.TestCase):
    def test_save_load_from_different_config(self):
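        # A config saved by one stub class should load into another class; keys
        # the target class does not expect must be warned about, then ignored.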
        obj = SchedulerObject()

        # mock add obj class to `diffusers`
        setattr(diffusers, "SchedulerObject", SchedulerObject)
        logger = logging.get_logger("diffusers.configuration_utils")

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            with CaptureLogger(logger) as cap_logger_1:
                config = SchedulerObject2.load_config(tmpdirname)
                new_obj_1 = SchedulerObject2.from_config(config)

            # now save a config parameter that is not expected
            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
                data = json.load(f)
                data["unexpected"] = True

            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
                json.dump(data, f)

            with CaptureLogger(logger) as cap_logger_2:
                config = SchedulerObject.load_config(tmpdirname)
                new_obj_2 = SchedulerObject.from_config(config)

            with CaptureLogger(logger) as cap_logger_3:
                config = SchedulerObject2.load_config(tmpdirname)
                new_obj_3 = SchedulerObject2.from_config(config)

        assert new_obj_1.__class__ == SchedulerObject2
        assert new_obj_2.__class__ == SchedulerObject
        assert new_obj_3.__class__ == SchedulerObject2

        assert cap_logger_1.out == ""
        assert (
            cap_logger_2.out
            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
            " will"
            " be ignored. Please verify your config.json configuration file.\n"
        )
        assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out

    def test_save_load_compatible_schedulers(self):
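        # Keys that a *compatible* class expects (here `f`) must be accepted
        # silently; only genuinely unknown keys should trigger the warning.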
        SchedulerObject2._compatibles = ["SchedulerObject"]
        SchedulerObject._compatibles = ["SchedulerObject2"]

        obj = SchedulerObject()

        # mock add obj class to `diffusers`
        setattr(diffusers, "SchedulerObject", SchedulerObject)
        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
        logger = logging.get_logger("diffusers.configuration_utils")

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            # now save a config parameter that is expected by another class, but not the origin class
            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
                data = json.load(f)
                data["f"] = [0, 0]
                data["unexpected"] = True

            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
                json.dump(data, f)

            with CaptureLogger(logger) as cap_logger:
                config = SchedulerObject.load_config(tmpdirname)
                new_obj = SchedulerObject.from_config(config)

        assert new_obj.__class__ == SchedulerObject

        assert (
            cap_logger.out
            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
            " will"
            " be ignored. Please verify your config.json configuration file.\n"
        )

    def test_save_load_from_different_config_comp_schedulers(self):
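        # With all three stubs marked mutually compatible, loading a config that
        # lacks a declared parameter should INFO-log that the missing value is
        # initialized to its default.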
        SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
        SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
        SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]

        obj = SchedulerObject()

        # mock add obj class to `diffusers`
        setattr(diffusers, "SchedulerObject", SchedulerObject)
        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
        setattr(diffusers, "SchedulerObject3", SchedulerObject3)
        logger = logging.get_logger("diffusers.configuration_utils")
        logger.setLevel(diffusers.logging.INFO)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_1:
                config = SchedulerObject.load_config(tmpdirname)
                new_obj_1 = SchedulerObject.from_config(config)

            with CaptureLogger(logger) as cap_logger_2:
                config = SchedulerObject2.load_config(tmpdirname)
                new_obj_2 = SchedulerObject2.from_config(config)

            with CaptureLogger(logger) as cap_logger_3:
                config = SchedulerObject3.load_config(tmpdirname)
                new_obj_3 = SchedulerObject3.from_config(config)

        assert new_obj_1.__class__ == SchedulerObject
        assert new_obj_2.__class__ == SchedulerObject2
        assert new_obj_3.__class__ == SchedulerObject3

        assert cap_logger_1.out == ""
        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"


class SchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

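    # Subclasses wire up this harness by filling in the two class attributes
    # above and implementing `get_scheduler_config`. A minimal sketch with a
    # hypothetical subclass and config values (not part of this file):
    #
    #     class MySchedulerTest(SchedulerCommonTest):
    #         scheduler_classes = (EulerDiscreteScheduler,)
    #         forward_default_kwargs = (("num_inference_steps", 10),)
    #
    #         def get_scheduler_config(self, **kwargs):
    #             config = {"num_train_timesteps": 1000}
    #             config.update(**kwargs)
    #             return config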
    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        sample = torch.rand((batch_size, num_channels, height, width))

        return sample

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = torch.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        sample = sample.permute(3, 0, 1, 2)

        return sample

    def get_scheduler_config(self):
        raise NotImplementedError

    def dummy_model(self):
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def check_over_configs(self, time_step=0, **config):
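        # Round-trip each scheduler through save_config/from_pretrained under
        # the given config overrides and assert that a single `step` call gives
        # numerically identical `prev_sample` outputs before and after.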
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps cast to float by default
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                time_step = float(time_step)

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class == VQDiffusionScheduler:
                num_vec_classes = scheduler_config["num_vec_classes"]
                sample = self.dummy_sample(num_vec_classes)
                model = self.dummy_model(num_vec_classes)
                residual = model(sample, time_step)
            else:
                sample = self.dummy_sample
                residual = 0.1 * sample

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # Make sure `scale_model_input` is invoked to prevent a warning
            if scheduler_class != VQDiffusionScheduler:
                _ = scheduler.scale_model_input(sample, 0)
                _ = new_scheduler.scale_model_input(sample, 0)

            # Set the seed before step() as some schedulers are stochastic, e.g. EulerAncestralDiscreteScheduler and EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
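        # Same round-trip check as `check_over_configs`, but varying the
        # forward/`step` kwargs instead of the constructor config.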
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                time_step = float(time_step)

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class == VQDiffusionScheduler:
                num_vec_classes = scheduler_config["num_vec_classes"]
                sample = self.dummy_sample(num_vec_classes)
                model = self.dummy_model(num_vec_classes)
                residual = model(sample, time_step)
            else:
                sample = self.dummy_sample
                residual = 0.1 * sample

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
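        # A scheduler restored via from_pretrained must reproduce the original's
        # `step` output exactly for a fixed seed.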
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            timestep = 1
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                timestep = float(timestep)

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class == VQDiffusionScheduler:
                num_vec_classes = scheduler_config["num_vec_classes"]
                sample = self.dummy_sample(num_vec_classes)
                model = self.dummy_model(num_vec_classes)
                residual = model(sample, timestep)
            else:
                sample = self.dummy_sample
                residual = 0.1 * sample

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
                new_scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_compatibles(self):
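        # Every class advertised by `compatibles` must be constructible from
        # this scheduler's config, and converting back must preserve all shared
        # config values.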
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()

            scheduler = scheduler_class(**scheduler_config)

            assert all(c is not None for c in scheduler.compatibles)

            for comp_scheduler_cls in scheduler.compatibles:
                comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
                assert comp_scheduler is not None

            new_scheduler = scheduler_class.from_config(comp_scheduler.config)

            new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
            scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}

            # make sure that configs are essentially identical
            assert new_scheduler_config == dict(scheduler.config)

            # make sure that only differences are for configs that are not in init
            init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
            assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()

    def test_from_pretrained(self):
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()

            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_pretrained(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            assert scheduler.config == new_scheduler.config

    def test_step_shape(self):
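        # `step` must return a `prev_sample` with the same shape as its input,
        # independent of the timestep.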
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        timestep_0 = 0
        timestep_1 = 1

        for scheduler_class in self.scheduler_classes:
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                timestep_0 = float(timestep_0)
                timestep_1 = float(timestep_1)

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class == VQDiffusionScheduler:
                num_vec_classes = scheduler_config["num_vec_classes"]
                sample = self.dummy_sample(num_vec_classes)
                model = self.dummy_model(num_vec_classes)
                residual = model(sample, timestep_0)
            else:
                sample = self.dummy_sample
                residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
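        # `step(..., return_dict=False)` must return the same tensors as the
        # dict-style output, compared field by field (NaNs zeroed out first).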
        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", 50)

        timestep = 0
        if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
            timestep = 1

        for scheduler_class in self.scheduler_classes:
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                timestep = float(timestep)

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class == VQDiffusionScheduler:
                num_vec_classes = scheduler_config["num_vec_classes"]
                sample = self.dummy_sample(num_vec_classes)
                model = self.dummy_model(num_vec_classes)
                residual = model(sample, timestep)
            else:
                sample = self.dummy_sample
                residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # Set the seed before step() as some schedulers are stochastic, e.g. EulerAncestralDiscreteScheduler and EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # Set the seed before step() as some schedulers are stochastic, e.g. EulerAncestralDiscreteScheduler and EulerDiscreteScheduler
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                kwargs["generator"] = torch.manual_seed(0)
            outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple, outputs_dict)

    def test_scheduler_public_api(self):
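        # All schedulers must expose `init_noise_sigma`, `scale_model_input`,
        # and `step` (VQDiffusionScheduler is exempt from the first two), and
        # `scale_model_input` must preserve the sample's shape.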
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            if scheduler_class != VQDiffusionScheduler:
                self.assertTrue(
                    hasattr(scheduler, "init_noise_sigma"),
                    f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
                )
                self.assertTrue(
                    hasattr(scheduler, "scale_model_input"),
                    (
                        f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
                        " timestep)`"
                    ),
                )
            self.assertTrue(
                hasattr(scheduler, "step"),
                f"{scheduler_class} does not implement a required class method `step(...)`",
            )

            if scheduler_class != VQDiffusionScheduler:
                sample = self.dummy_sample
                scaled_sample = scheduler.scale_model_input(sample, 0.0)
                self.assertEqual(sample.shape, scaled_sample.shape)

    def test_add_noise_device(self):
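        # `scale_model_input` and `add_noise` must accept tensors on
        # `torch_device` and preserve the sample shape.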
        for scheduler_class in self.scheduler_classes:
            if scheduler_class == IPNDMScheduler:
                continue
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(100)

            sample = self.dummy_sample.to(torch_device)
            scaled_sample = scheduler.scale_model_input(sample, 0.0)
            self.assertEqual(sample.shape, scaled_sample.shape)

            noise = torch.randn_like(scaled_sample).to(torch_device)
            t = scheduler.timesteps[5][None]
            noised = scheduler.add_noise(scaled_sample, noise, t)
            self.assertEqual(noised.shape, scaled_sample.shape)

    def test_deprecated_kwargs(self):
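        # `**kwargs` in a scheduler's __init__ and a non-empty
        # `_deprecated_kwargs` list must appear together or not at all.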
        for scheduler_class in self.scheduler_classes:
            has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
            has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0

            if has_kwarg_in_model_class and not has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
                    " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                    " [<deprecated_argument>]`"
                )

            if not has_kwarg_in_model_class and has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
                    f" argument to {scheduler_class}.__init__ if there are deprecated arguments or remove the"
                    " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
                )

    def test_trained_betas(self):
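        # A `trained_betas` array passed at construction must survive a
        # save_pretrained/from_pretrained round trip unchanged.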
        for scheduler_class in self.scheduler_classes:
            if scheduler_class == VQDiffusionScheduler:
                continue

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3]))

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_pretrained(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()