Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_utils.py
1440 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
17
18
from diffusers import __version__
19
from diffusers.utils import deprecate
20
21
22
class DeprecateTester(unittest.TestCase):
    """Tests for `diffusers.utils.deprecate`: emitted warnings, returned values, and error paths."""

    # The installed diffusers version with its major component bumped by one (e.g. "0.15.0" -> "1.15.0").
    # Used as a removal version that has not been reached yet, so `deprecate` is expected to warn.
    higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
    # A removal version the installed diffusers has already passed, so `deprecate` is expected to raise.
    lower_version = "0.0.1"

    def test_deprecate_function_arg(self):
        """A single deprecated kwarg, passed as three positional args, is returned and warned about."""
        kwargs = {"deprecated_arg": 4}

        with self.assertWarns(FutureWarning) as warning:
            output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs)

        assert output == 4
        assert (
            str(warning.warning)
            == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
            " message"
        )

    def test_deprecate_function_arg_tuple(self):
        """Same as the positional form, but the deprecation is passed as a single tuple."""
        kwargs = {"deprecated_arg": 4}

        with self.assertWarns(FutureWarning) as warning:
            output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs)

        assert output == 4
        assert (
            str(warning.warning)
            == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}."
            " message"
        )

    def test_deprecate_function_args(self):
        """Several deprecation tuples return their values in order and emit one warning each."""
        kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8}
        with self.assertWarns(FutureWarning) as warning:
            output_1, output_2 = deprecate(
                ("deprecated_arg_1", self.higher_version, "Hey"),
                ("deprecated_arg_2", self.higher_version, "Hey"),
                take_from=kwargs,
            )
        assert output_1 == 4
        assert output_2 == 8
        assert (
            str(warning.warnings[0].message)
            == "The `deprecated_arg_1` argument is deprecated and will be removed in version"
            f" {self.higher_version}. Hey"
        )
        assert (
            str(warning.warnings[1].message)
            == "The `deprecated_arg_2` argument is deprecated and will be removed in version"
            f" {self.higher_version}. Hey"
        )

    def test_deprecate_function_incorrect_arg(self):
        """A kwarg left over in `take_from` raises TypeError naming the calling function, a line, and the key."""
        kwargs = {"deprecated_arg": 4}

        with self.assertRaises(TypeError) as error:
            deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs)

        assert "test_deprecate_function_incorrect_arg in" in str(error.exception)
        assert "line" in str(error.exception)
        assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception)

    def test_deprecate_arg_no_kwarg(self):
        """Without `take_from`, a generic (non argument/attribute) deprecation message is emitted."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "message"))

        assert (
            str(warning.warning)
            == f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_args_no_kwarg(self):
        """Several tuples without `take_from` each emit their own generic warning, in order."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(
                ("deprecated_arg_1", self.higher_version, "Hey"),
                ("deprecated_arg_2", self.higher_version, "Hey"),
            )
        assert (
            str(warning.warnings[0].message)
            == f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey"
        )
        assert (
            str(warning.warnings[1].message)
            == f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey"
        )

    def test_deprecate_class_obj(self):
        """When `take_from` is an object, the deprecated name is read as an attribute and returned."""

        class Args:
            arg = 5

        with self.assertWarns(FutureWarning) as warning:
            arg = deprecate(("arg", self.higher_version, "message"), take_from=Args())

        assert arg == 5
        assert (
            str(warning.warning)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_class_objs(self):
        """Multiple attributes are returned in order; a name the object lacks contributes no value."""

        class Args:
            arg = 5
            foo = 7

        with self.assertWarns(FutureWarning) as warning:
            arg_1, arg_2 = deprecate(
                ("arg", self.higher_version, "message"),
                ("foo", self.higher_version, "message"),
                ("does not exist", self.higher_version, "message"),
                take_from=Args(),
            )

        assert arg_1 == 5
        assert arg_2 == 7
        # `warning.warning` is the first matching warning, so it must agree with `warnings[0]`.
        assert (
            str(warning.warning)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )
        assert (
            str(warning.warnings[0].message)
            == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )
        assert (
            str(warning.warnings[1].message)
            == f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message"
        )

    def test_deprecate_incorrect_version(self):
        """A deprecation whose removal version is already reached raises ValueError quoting the tuple."""
        kwargs = {"deprecated_arg": 4}
        # Build the tuple once so the expected message cannot drift from the value actually passed in.
        deprecation_tuple = ("wrong_arg", self.lower_version, "message")

        with self.assertRaises(ValueError) as error:
            deprecate(deprecation_tuple, take_from=kwargs)

        assert (
            str(error.exception)
            == f"The deprecation tuple {deprecation_tuple} should be removed since diffusers' version"
            f" {__version__} is >= {self.lower_version}"
        )

    def test_deprecate_incorrect_no_standard_warn(self):
        """With `standard_warn=False` only the custom message is emitted, without the standard prefix."""
        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)

        assert str(warning.warning) == "This message is better!!!"

    def test_deprecate_stacklevel(self):
        """The warning is attributed to the code calling `deprecate` (this test module)."""
        import os

        with self.assertWarns(FutureWarning) as warning:
            deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
        assert str(warning.warning) == "This message is better!!!"
        # Compare against this module's own path rather than a hard-coded "diffusers/tests/..." substring,
        # so the check does not depend on the checkout directory name or the OS path separator.
        assert os.path.normcase(os.path.abspath(warning.filename)) == os.path.normcase(os.path.abspath(__file__))