Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_modeling_common_flax.py
1440 views
1
import inspect
2
3
from diffusers.utils import is_flax_available
4
from diffusers.utils.testing_utils import require_flax
5
6
7
if is_flax_available():
8
import jax
9
10
11
@require_flax
class FlaxModelTesterMixin:
    """Shared tests for Flax models.

    Meant to be mixed into a test-case class that provides:

    * ``model_class`` -- the Flax model class under test;
    * ``prepare_init_args_and_inputs_for_common()`` -- returns
      ``(init_dict, inputs_dict)``, where ``init_dict`` holds the model's
      constructor kwargs and ``inputs_dict`` holds at least ``"prng_key"``
      and ``"sample"``;
    * the ``assertIsNotNone`` / ``assertEqual`` assertion methods
      (presumably from ``unittest.TestCase`` -- confirm in subclasses).
    """

    def _check_forward_preserves_sample_shape(self, init_dict, inputs_dict):
        """Build the model, run one forward pass and check the output shape.

        Instantiates ``self.model_class(**init_dict)``, initializes it with
        ``inputs_dict["prng_key"]`` and ``inputs_dict["sample"]``, applies it
        to ``inputs_dict["sample"]`` and asserts that the output is not
        ``None`` and has the same shape as ``inputs_dict["sample"]``.
        """
        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        # JAX is functional: stop_gradient returns a new pytree rather than
        # modifying its argument, so the result must be rebound to have any
        # effect (the previous bare call was a no-op).
        variables = jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        # Models may return a dict-like output object instead of a bare
        # array; in that case the prediction is read from ``.sample``.
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_output(self):
        """With the default config, the output has the same shape as ``sample``."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        self._check_forward_preserves_sample_shape(init_dict, inputs_dict)

    def test_forward_with_norm_groups(self):
        """With ``norm_num_groups=16``, the output still has the shape of ``sample``."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        # NOTE(review): channel counts are multiples of norm_num_groups,
        # presumably because group normalization requires divisibility --
        # confirm against the model implementation.
        init_dict["block_out_channels"] = (16, 32)

        self._check_forward_preserves_sample_shape(init_dict, inputs_dict)

    def test_deprecated_kwargs(self):
        """Require ``**kwargs`` in ``__init__`` iff ``_deprecated_kwargs`` is non-empty.

        Raises:
            ValueError: if the model accepts ``**kwargs`` but declares no
                deprecated kwargs, or declares deprecated kwargs but does not
                accept ``**kwargs``.
        """
        parameters = inspect.signature(self.model_class.__init__).parameters.values()
        # Detect a var-keyword parameter by its kind rather than by the
        # literal name "kwargs", so that e.g. ``**extra`` is caught as well.
        has_kwarg_in_model_class = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters)
        # A class that does not define the attribute is treated as declaring
        # no deprecated kwargs instead of failing with AttributeError.
        has_deprecated_kwarg = bool(getattr(self.model_class, "_deprecated_kwargs", None))

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )
67
68