# Path: tests/models/test_models_vae_flax.py (blob/main)
import unittest

from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax

from ..test_modeling_common_flax import FlaxModelTesterMixin


# jax is only importable when Flax is installed; the test class below is
# gated by @require_flax, so `jax` is never referenced otherwise.
if is_flax_available():
    import jax


@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
    """Shared Flax model tests applied to a tiny FlaxAutoencoderKL.

    The actual test methods come from FlaxModelTesterMixin; this class only
    supplies the model class, a small config, and dummy inputs.
    """

    model_class = FlaxAutoencoderKL

    @property
    def dummy_input(self):
        """Build a fresh batch of random images plus a PRNG key.

        Returns:
            dict with
              - "sample": float array of shape (4, 3, 32, 32) drawn uniformly
                from [0, 1) (batch, channels, height, width).
              - "prng_key": a JAX PRNG key for the model side of the test
                (presumably used by the mixin for init/sampling — see
                FlaxModelTesterMixin).
        """
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        # Split the fixed seed so the image draw and the key handed to the
        # model are independent; a JAX key must not be consumed twice.
        image_key, prng_key = jax.random.split(jax.random.PRNGKey(0))
        image = jax.random.uniform(image_key, (batch_size, num_channels) + sizes)

        return {"sample": image, "prng_key": prng_key}

    def prepare_init_args_and_inputs_for_common(self):
        """Return (init_dict, inputs_dict) for the shared mixin tests.

        init_dict holds constructor kwargs for a small two-block
        FlaxAutoencoderKL; inputs_dict is a fresh `dummy_input`.
        """
        init_dict = {
            # Same length as the block-type lists below (presumably one
            # channel count per encoder/decoder block).
            "block_out_channels": [32, 64],
            # Matches num_channels used in dummy_input.
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict