Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_config.py
1441 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import tempfile
17
import unittest
18
19
from diffusers import (
20
DDIMScheduler,
21
DDPMScheduler,
22
DPMSolverMultistepScheduler,
23
EulerAncestralDiscreteScheduler,
24
EulerDiscreteScheduler,
25
PNDMScheduler,
26
logging,
27
)
28
from diffusers.configuration_utils import ConfigMixin, register_to_config
29
from diffusers.utils.testing_utils import CaptureLogger
30
31
32
class SampleObject(ConfigMixin):
    """Minimal ``ConfigMixin`` fixture whose ``__init__`` signature is the config under test.

    ``register_to_config`` records these arguments, defaults included, into ``self.config``.
    The parameter names, their order and their default values are therefore test data and
    must not change. ``e`` deliberately has a list default (safe here because it is never
    mutated), so list-valued entries are exercised, including their JSON round-trip.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]):  # noqa: B006 - default is config data
        """Do nothing; the decorator captures the call's arguments into the config."""
45
46
47
class SampleObject2(ConfigMixin):
    """``ConfigMixin`` fixture with ``f`` in the final position instead of ``e``.

    Not referenced by the tests visible here. It is presumably meant for checks where saved
    and expected config keys differ (TODO: confirm). ``register_to_config`` reads this
    signature, so the names, order and defaults must stay exactly as they are.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", f=[1, 3]):  # noqa: B006 - default is config data
        """Do nothing; the decorator captures the call's arguments into the config."""
60
61
62
class SampleObject3(ConfigMixin):
    """``ConfigMixin`` fixture that declares both ``e`` and ``f``.

    Not referenced by the tests visible here. It is presumably meant for checks where saved
    and expected config keys differ (TODO: confirm). ``register_to_config`` reads this
    signature, so the names, order and defaults must stay exactly as they are.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3], f=[1, 3]):  # noqa: B006
        """Do nothing; the decorator captures the call's arguments into the config."""
76
77
78
class ConfigTester(unittest.TestCase):
    """Tests for ``ConfigMixin`` / ``register_to_config`` and for loading scheduler configs.

    The ``test_load_*`` and ``test_overwrite_config_on_load`` cases call ``from_pretrained``
    with Hub repository ids. They need network access or a warm local cache.
    """

    # Repository used by most of the scheduler loading tests.
    _TINY_SD_REPO = "hf-internal-testing/tiny-stable-diffusion-torch"
    # Config values SampleObject records when constructed with no arguments.
    _SAMPLE_OBJECT_DEFAULTS = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}

    def _assert_config_values(self, config, expected):
        """Assert ``config[key] == value`` for each item of ``expected``; extra keys are allowed."""
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertEqual(config[key], value)

    def _load_without_config_warnings(self, scheduler_cls, *args, **kwargs):
        """Call ``scheduler_cls.from_pretrained`` and assert nothing was logged about the config.

        Output of the ``diffusers.configuration_utils`` logger is captured during the load.
        The logger's level is pinned to WARNING for that window and then restored, so other
        tests are not affected.

        Args:
            scheduler_cls: Class whose ``from_pretrained`` is called. The result must be
                exactly this class, not a subclass.
            *args, **kwargs: Forwarded unchanged to ``from_pretrained``.

        Returns:
            The loaded scheduler instance.
        """
        import logging as py_logging  # stdlib; the module-level ``logging`` is diffusers' wrapper

        logger = logging.get_logger("diffusers.configuration_utils")
        previous_level = logger.level
        # If this logger's effective level sits above WARNING, warnings are dropped before they
        # reach the capture handler and the "no warning" check below would pass vacuously.
        logger.setLevel(py_logging.WARNING)
        try:
            with CaptureLogger(logger) as cap_logger:
                scheduler = scheduler_cls.from_pretrained(*args, **kwargs)
        finally:
            logger.setLevel(previous_level)

        self.assertIs(type(scheduler), scheduler_cls)
        # no warning should be thrown
        self.assertEqual(cap_logger.out, "")
        return scheduler

    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.load_config("dummy_path")

    def test_register_to_config(self):
        defaults = self._SAMPLE_OBJECT_DEFAULTS

        # defaults are recorded
        self._assert_config_values(SampleObject().config, defaults)

        # init ignores private arguments
        self._assert_config_values(SampleObject(_name_or_path="lalala").config, defaults)

        # can override default
        self._assert_config_values(SampleObject(c=6).config, {**defaults, "c": 6})

        # can use positional arguments
        self._assert_config_values(SampleObject(1, c=6).config, {**defaults, "a": 1, "c": 6})

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config
        self._assert_config_values(config, self._SAMPLE_OBJECT_DEFAULTS)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
            new_config = new_obj.config

        # unfreeze configs so entries can be popped
        config = dict(config)
        new_config = dict(new_config)

        self.assertEqual(config.pop("c"), (2, 5))  # instantiated as tuple
        self.assertEqual(new_config.pop("c"), [2, 5])  # saved & loaded as list because of json
        self.assertEqual(config, new_config)

    def test_load_ddim_from_pndm(self):
        self._load_without_config_warnings(DDIMScheduler, self._TINY_SD_REPO, subfolder="scheduler")

    def test_load_euler_from_pndm(self):
        self._load_without_config_warnings(EulerDiscreteScheduler, self._TINY_SD_REPO, subfolder="scheduler")

    def test_load_euler_ancestral_from_pndm(self):
        self._load_without_config_warnings(
            EulerAncestralDiscreteScheduler, self._TINY_SD_REPO, subfolder="scheduler"
        )

    def test_load_pndm(self):
        self._load_without_config_warnings(PNDMScheduler, self._TINY_SD_REPO, subfolder="scheduler")

    def test_overwrite_config_on_load(self):
        # kwargs passed to from_pretrained override the stored config values
        ddpm = self._load_without_config_warnings(
            DDPMScheduler,
            self._TINY_SD_REPO,
            subfolder="scheduler",
            prediction_type="sample",
            beta_end=8,
        )
        ddpm_2 = self._load_without_config_warnings(DDPMScheduler, "google/ddpm-celebahq-256", beta_start=88)

        self.assertEqual(ddpm.config.prediction_type, "sample")
        self.assertEqual(ddpm.config.beta_end, 8)
        self.assertEqual(ddpm_2.config.beta_start, 88)

    def test_load_dpmsolver(self):
        self._load_without_config_warnings(DPMSolverMultistepScheduler, self._TINY_SD_REPO, subfolder="scheduler")
224
225