Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_pipelines_flax.py
1440 views
1
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import os
17
import tempfile
18
import unittest
19
20
import numpy as np
21
22
from diffusers.utils import is_flax_available
23
from diffusers.utils.testing_utils import require_flax, slow
24
25
26
if is_flax_available():
27
import jax
28
import jax.numpy as jnp
29
from flax.jax_utils import replicate
30
from flax.training.common_utils import shard
31
from jax import pmap
32
33
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
34
35
36
@require_flax
class DownloadTests(unittest.TestCase):
    """Checks which weight files ``FlaxDiffusionPipeline.from_pretrained`` downloads."""

    @staticmethod
    def _snapshot_files(cache_dir):
        """Return the names of all files found under ``<entry>/snapshots`` in ``cache_dir``.

        Every top-level entry of ``cache_dir`` that contains a ``snapshots`` folder is
        scanned. The scan does not rely on whichever entry ``os.listdir`` happens to
        return first, so an unrelated entry in the cache root cannot silently turn
        the scan into a no-op.

        Args:
            cache_dir: the directory passed as ``cache_dir`` to ``from_pretrained``.

        Returns:
            A flat list of file names (not paths) from all snapshot folders.
        """
        files = []
        for entry in sorted(os.listdir(cache_dir)):
            snapshots_dir = os.path.join(cache_dir, entry, "snapshots")
            if not os.path.isdir(snapshots_dir):
                continue
            for _, _, filenames in os.walk(snapshots_dir):
                files.extend(filenames)
        return files

    def test_download_only_pytorch(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has Flax weights
            _ = FlaxDiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

            files = self._snapshot_files(tmpdirname)

            # Without this guard an empty scan (e.g. wrong folder) would make the
            # ".bin" check below pass vacuously.
            assert files, f"no downloaded snapshot files found under {tmpdirname}"

            # None of the downloaded files should be a PyTorch file even if we have some here:
            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
            assert not any(f.endswith(".bin") for f in files)
51
52
53
@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):
    """Slow end-to-end sampling tests for ``FlaxStableDiffusionPipeline``.

    Each test generates one image per local JAX device from the same prompt and
    checks the output shape. Recorded reference pixel sums are compared only when
    ``jax.device_count() == 8``; on other device counts only the shape is checked.
    """

    # Prompt shared by every test; it is repeated once per device.
    _PROMPT = (
        "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
        " field, close up, split lighting, cinematic"
    )

    def _shard_inputs(self, pipeline, params):
        """Prepare device-parallel inputs for one sampling run.

        Args:
            pipeline: the loaded ``FlaxStableDiffusionPipeline``.
            params: the parameters returned alongside it by ``from_pretrained``.

        Returns:
            ``(prompt_ids, params, prng_seed, num_samples)`` where ``num_samples`` is
            ``jax.device_count()``. ``prompt_ids`` is the tokenized prompt repeated
            ``num_samples`` times and passed through flax ``shard``. ``params`` is
            passed through flax ``replicate``. ``prng_seed`` is ``PRNGKey(0)`` split
            into ``num_samples`` keys.
        """
        num_samples = jax.device_count()
        prompt_ids = pipeline.prepare_inputs(num_samples * [self._PROMPT])

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
        prompt_ids = shard(prompt_ids)
        return prompt_ids, params, prng_seed, num_samples

    def _sample(self, pipeline, params, num_inference_steps, jit=False):
        """Run the pipeline on every local device.

        Args:
            pipeline: the loaded ``FlaxStableDiffusionPipeline``.
            params: unreplicated pipeline parameters.
            num_inference_steps: number of inference steps, passed positionally.
            jit: if False, ``pipeline.__call__`` is wrapped in an explicit ``pmap``.
                If True, the pipeline is called directly with its own ``jit=True``
                option instead.

        Returns:
            ``(images, num_samples)``.
        """
        prompt_ids, params, prng_seed, num_samples = self._shard_inputs(pipeline, params)
        if jit:
            output = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True)
        else:
            # Argument 3 (the step count) is a plain Python int, so it is broadcast
            # as a static value instead of being mapped over devices.
            p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
            output = p_sample(prompt_ids, params, prng_seed, num_inference_steps)
        return output.images, num_samples

    def _check_reference_sums(self, images, slice_sum, total_sum):
        """Compare image statistics with recorded reference values on 8-device hosts.

        Args:
            images: sampled images as returned by ``_sample``.
            slice_sum: expected sum of absolute values of ``images[0, 0, :2, :2, -2:]``
                (tolerance 1e-3).
            total_sum: expected sum of absolute values over all of ``images``
                (tolerance 5e-1).
        """
        if jax.device_count() != 8:
            return
        assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - slice_sum) < 1e-3
        assert np.abs(np.abs(images, dtype=np.float32).sum() - total_sum) < 5e-1

    def test_dummy_all_tpus(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
        )

        images, num_samples = self._sample(pipeline, params, num_inference_steps=4)

        assert images.shape == (num_samples, 1, 64, 64, 3)
        self._check_reference_sums(images, 3.1111548, 199746.95)

        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

        assert len(images_pil) == num_samples

    def test_stable_diffusion_v1_4(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
        )

        images, num_samples = self._sample(pipeline, params, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._check_reference_sums(images, 0.05652401, 2383808.2)

    def test_stable_diffusion_v1_4_bfloat_16(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
        )

        images, num_samples = self._sample(pipeline, params, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._check_reference_sums(images, 0.06652832, 2384849.8)

    def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
        # No ``safety_checker=None`` override here, so the checker shipped with the
        # checkpoint is presumably loaded; sampling uses the pipeline's own jit path.
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
        )

        images, num_samples = self._sample(pipeline, params, num_inference_steps=50, jit=True)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._check_reference_sums(images, 0.06652832, 2384849.8)

    def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
        scheduler = FlaxDDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            set_alpha_to_one=False,
            steps_offset=1,
        )

        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            scheduler=scheduler,
            safety_checker=None,
        )
        # The DDIM scheduler was constructed locally, so its state is created here
        # and stored under params["scheduler"] before the params get replicated.
        params["scheduler"] = scheduler.create_state()

        images, num_samples = self._sample(pipeline, params, num_inference_steps=50)

        assert images.shape == (num_samples, 1, 512, 512, 3)
        self._check_reference_sums(images, 0.045043945, 2347693.5)
227
228