Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/models/test_models_vae_flax.py
1448 views
1
import unittest
2
3
from diffusers import FlaxAutoencoderKL
4
from diffusers.utils import is_flax_available
5
from diffusers.utils.testing_utils import require_flax
6
7
from ..test_modeling_common_flax import FlaxModelTesterMixin
8
9
10
if is_flax_available():
11
import jax
12
13
14
@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
    """Run the shared Flax model checks from ``FlaxModelTesterMixin`` on a small ``FlaxAutoencoderKL``.

    ``@require_flax`` presumably skips the whole class when Flax is not installed
    (``jax`` is only imported by this module when ``is_flax_available()``) --
    confirm in ``diffusers.utils.testing_utils``.
    """

    model_class = FlaxAutoencoderKL

    @property
    def dummy_input(self):
        """Return a fresh, deterministic input batch for the autoencoder.

        Returns:
            dict with
              ``"sample"``: array of shape ``(batch_size, num_channels, *sizes)``,
                i.e. ``(4, 3, 32, 32)``, drawn uniformly from ``[0, 1)``
                (``jax.random.uniform`` defaults);
              ``"prng_key"``: a JAX PRNG key, presumably used by the mixin to
                initialise the model -- confirm in ``FlaxModelTesterMixin``.
        """
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        # Derive two independent keys from one fixed seed. JAX keys must not be
        # reused: drawing the image from the same key that is handed out as
        # ``prng_key`` would make both consumers read the same random stream.
        image_key, prng_key = jax.random.split(jax.random.PRNGKey(0))
        image = jax.random.uniform(image_key, (batch_size, num_channels, *sizes))

        return {"sample": image, "prng_key": prng_key}

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` for the shared mixin tests.

        ``init_dict`` holds keyword arguments for ``FlaxAutoencoderKL`` --
        presumably passed as ``model_class(**init_dict)`` by
        ``FlaxModelTesterMixin`` (confirm). It is a small configuration with two
        down and two up block types, 32 and 64 output channels, and 4 latent
        channels. ``inputs_dict`` is a fresh ``dummy_input`` batch.
        """
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,  # same as num_channels in dummy_input
            "out_channels": 3,
            # Both block-type lists have one entry per element of block_out_channels.
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
40
41