# Path: blob/main/tests/test_modeling_common_flax.py
# 1440 views
import inspect

from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax


@require_flax
class FlaxModelTesterMixin:
    """Shared checks for Flax model test cases.

    The host class must provide ``model_class`` and
    ``prepare_init_args_and_inputs_for_common()``. That method returns
    ``(init_dict, inputs_dict)``, and ``inputs_dict`` must contain the keys
    ``"prng_key"`` and ``"sample"``. The host class is presumably a
    ``unittest.TestCase`` subclass, since ``assertEqual`` / ``assertIsNotNone``
    are used.
    """

    def _check_forward_output_shape(self, init_dict, inputs_dict):
        """Build the model, run one forward pass, and assert output shape == input shape."""
        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        # stop_gradient returns a new pytree; discarding the result would make the call a no-op.
        variables = jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        # Some models wrap the result in a dict-like output object exposing `.sample`.
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_output(self):
        """Check that a forward pass with the default config preserves the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        self._check_forward_output_shape(init_dict, inputs_dict)

    def test_forward_with_norm_groups(self):
        """Check that a forward pass with custom group-norm settings preserves the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        self._check_forward_output_shape(init_dict, inputs_dict)

    def test_deprecated_kwargs(self):
        """Require ``**kwargs`` in ``__init__`` if and only if ``_deprecated_kwargs`` is non-empty.

        Raises:
            ValueError: if the model accepts ``**kwargs`` but declares no
                deprecated kwargs, or declares deprecated kwargs without
                accepting ``**kwargs``.
        """
        parameters = inspect.signature(self.model_class.__init__).parameters.values()
        # Detect a real `**kwargs`-style catch-all by kind, not by the literal name "kwargs".
        has_kwarg_in_model_class = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters)
        # Treat a missing attribute as "no deprecated kwargs" so the descriptive error below fires.
        has_deprecated_kwarg = len(getattr(self.model_class, "_deprecated_kwargs", ())) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )