"""
Title: Stable Diffusion 3 in KerasHub!
Author: [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta)
Date created: 2024/10/09
Last modified: 2024/10/24
Description: Image generation using KerasHub's Stable Diffusion 3 model.
Accelerator: GPU
"""

"""
## Overview

Stable Diffusion 3 is a powerful, open-source latent diffusion model (LDM)
designed to generate high-quality novel images based on text prompts. Released
by [Stability AI](https://stability.ai/), it was pre-trained on 1 billion
images and fine-tuned on 33 million high-quality aesthetic and preference
images, resulting in greatly improved performance compared to previous versions
of Stable Diffusion models.

In this guide, we will explore KerasHub's implementation of
[Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium),
including the text-to-image, image-to-image, and inpaint tasks.

To get started, let's install a few dependencies and get images for our demo:
"""

"""shell
!pip install -Uq keras
!pip install -Uq git+https://github.com/keras-team/keras-hub.git
!wget -O mountain_dog.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog.png
!wget -O mountain_dog_mask.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog_mask.png
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import time

import keras
import keras_hub
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

"""
## Introduction

Before diving into how latent diffusion models work, let's start by generating
some images using KerasHub's APIs.

To avoid reinitializing variables for different tasks, we'll instantiate and
load the trained `backbone` and `preprocessor` using KerasHub's `from_preset`
factory method. If you only want to perform one task at a time, you can use a
simpler API like this:

```python
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
```

This automatically loads and configures the trained `backbone` and
`preprocessor` for you.

Note that in this guide, we'll use `image_shape=(512, 512, 3)` for faster
image generation. For higher-quality output, it's recommended to use the default
size of `1024`. The entire backbone has about 3 billion parameters, which can be
challenging to fit into a consumer-level GPU, so we set `dtype="float16"` to
reduce GPU memory usage -- the officially released weights are also in
float16.

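Here is a back-of-the-envelope check of why float16 helps. The sketch below estimates the memory taken by the weights alone. It assumes exactly 3 billion parameters, the approximate figure quoted above, and it ignores activations, runtime buffers, and framework overhead, so actual GPU usage will be higher.

```python
# Back-of-the-envelope estimate of weight memory for a ~3B-parameter backbone.
# Assumption: exactly 3e9 parameters; activations and framework overhead are
# not included, so real GPU usage will be higher.
num_params = 3_000_000_000
bytes_per_param = {"float32": 4, "float16": 2}

for dtype_name, num_bytes in bytes_per_param.items():
    gib = num_params * num_bytes / 1024**3
    print(f"{dtype_name}: ~{gib:.1f} GiB of weights")
```

Under these assumptions, float32 weights need roughly 11.2 GiB and float16 weights roughly 5.6 GiB. Halving the bytes per parameter halves the weight footprint, which matters on consumer-level GPUs.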

It is also worth noting that the preset "stable_diffusion_3_medium" excludes the
T5XXL text encoder, as it requires significantly more GPU memory. The performance
degradation is negligible in most cases. The weights, including T5XXL, will be
available on KerasHub soon.
"""


def display_generated_images(images):
    """Helper function to display the images from the inputs.

    This function accepts the following input formats:
    - 3D numpy array.
    - 4D numpy array: concatenated horizontally.
    - List of 3D numpy arrays: concatenated horizontally.
    """
    display_image = None
    if isinstance(images, np.ndarray):
        if images.ndim == 3:
            display_image = Image.fromarray(images)
        elif images.ndim == 4:
            concatenated_images = np.concatenate(list(images), axis=1)
            display_image = Image.fromarray(concatenated_images)
    elif isinstance(images, list):
        concatenated_images = np.concatenate(images, axis=1)
        display_image = Image.fromarray(concatenated_images)

    if display_image is None:
        raise ValueError("Unsupported input format.")

    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(display_image)
    plt.show()
    plt.close()


backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3), dtype="float16"
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)
text_to_image = keras_hub.models.StableDiffusion3TextToImage(backbone, preprocessor)

"""
Next, we give it a prompt:
"""

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# When using the JAX or TensorFlow backend, the first `generate()` call may take
# significantly longer because of compilation. Subsequent `generate()` calls
# are much faster, which highlights the power of JIT compilation and caching in
# frameworks like JAX and TensorFlow, making them well-suited for
# high-performance deep learning tasks like image generation.
generated_image = text_to_image.generate(prompt)
display_generated_images(generated_image)


"""
Pretty impressive! But how does this work?

Let's dig into what "latent diffusion model" means.

Consider the concept of "super-resolution," where a deep learning model
"denoises" an input image, turning it into a higher-resolution version. The
model uses its training data distribution to hallucinate the visual details that
are most likely given the input. To learn more about super-resolution, you can
check out the following Keras.io tutorials:

- [Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/)
- [Enhanced Deep Residual Networks for single-image super-resolution](https://keras.io/examples/vision/edsr/)

![Super-resolution](https://i.imgur.com/M0XdqOo.png)

When we push this idea to the limit, we may start asking -- what if we just run
such a model on pure noise? The model would then "denoise the noise" and start
hallucinating a brand new image. By repeating the process multiple times, we
can turn a small patch of noise into an increasingly clear and
high-resolution artificial picture.

This is the key idea of latent diffusion, proposed in
[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).
To understand diffusion in depth, you can check the Keras.io tutorial
[Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/).

![Denoising diffusion](https://i.imgur.com/FSCKtZq.gif)

To transition from latent diffusion to a text-to-image system, one key feature
must be added: the ability to control the generated visual content using prompt
keywords. In Stable Diffusion 3, the text encoders from the CLIP and T5XXL
models are used to obtain text embeddings, which are then fed into the diffusion
model to condition the diffusion process. This approach is based on the concept
of "classifier-free guidance", proposed in
[Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).

When we combine these ideas, we get a high-level overview of the architecture of
Stable Diffusion 3:

- Text encoders: Convert the text prompt into text embeddings.
- Diffusion model: Repeatedly "denoises" a smaller latent image patch.
- Decoder: Transforms the final latent patch into a higher-resolution image.

First, the text prompt is converted into text embeddings by multiple text
encoders, which are pretrained and frozen language models. Next, the text
embeddings, along with a randomly generated noise patch (typically drawn from a
Gaussian distribution), are fed into the diffusion model. The diffusion model
repeatedly "denoises" the noise patch over a series of steps (the more steps,
the clearer and more refined the image becomes -- the default value is 28
steps). Finally, the latent patch is passed through the decoder from the VAE
model to render the image in high resolution.

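To make this data flow concrete, here is a toy NumPy sketch of the three stages. Everything in it is hypothetical. `encode_text`, `denoise_step`, `decode`, and `toy_generate` are stand-ins written for this illustration: they are not KerasHub APIs and contain no learned weights. The latent shape `(64, 64, 16)` is also an assumption (8x spatial downsampling with 16 channels for a 512x512 image). Only the structure mirrors the description above: text embeddings condition a loop that repeatedly refines a Gaussian noise patch, and a decoder maps the final latent to image resolution.

```python
import numpy as np

# Toy stand-ins for the three stages described above. None of these are
# KerasHub APIs; they only mimic the shapes and the order of operations.
LATENT_SHAPE = (64, 64, 16)  # assumed latent for a 512x512 image
UPSAMPLE_FACTOR = 8


def encode_text(prompt, dim=16):
    # Hypothetical text encoder: deterministically maps a prompt to a vector.
    seed = sum(ord(char) for char in prompt)
    return np.random.default_rng(seed).standard_normal(dim)


def denoise_step(latent, text_embedding, step_size=0.2):
    # Hypothetical denoiser: moves the latent a fraction of the way toward a
    # target given by the text embedding (broadcast over height and width).
    return latent + step_size * (text_embedding - latent)


def decode(latent):
    # Hypothetical decoder: nearest-neighbor upsampling to image resolution,
    # keeping the first 3 channels as "RGB".
    image = latent.repeat(UPSAMPLE_FACTOR, axis=0).repeat(UPSAMPLE_FACTOR, axis=1)
    return image[..., :3]


def toy_generate(prompt, num_steps=28, seed=0):
    text_embedding = encode_text(prompt)  # 1. text encoders
    # Start from a Gaussian noise patch in the (smaller) latent space.
    latent = np.random.default_rng(seed).standard_normal(LATENT_SHAPE)
    for _ in range(num_steps):  # 2. repeated denoising
        latent = denoise_step(latent, text_embedding)
    return decode(latent)  # 3. decoder to full resolution


toy_image = toy_generate("Astronaut in a jungle")
print(toy_image.shape)  # (512, 512, 3)
```

Each toy step shrinks the remaining noise by a constant factor, so more steps leave less residual noise. This loosely echoes the `num_steps` trade-off. The real diffusion model instead predicts each update with a large learned network.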

An overview of the Stable Diffusion 3 architecture:

![The Stable Diffusion 3 architecture](https://i.imgur.com/D9y0fWF.png)

This relatively simple system starts looking like magic once we train on
billions of pictures and their captions. As Feynman said about the universe:
_"It's not complicated, it's just a lot of it!"_
"""


"""
## Text-to-image task

Now that we know the basics of Stable Diffusion 3 and the text-to-image task,
let's explore further using KerasHub's APIs.

For efficient batch processing, we can provide the model with a list of
prompts:
"""


generated_images = text_to_image.generate([prompt] * 3)
display_generated_images(generated_images)

"""
The `num_steps` parameter controls the number of denoising steps used during
image generation. Increasing the number of steps typically leads to
higher-quality images at the expense of increased generation time. In
Stable Diffusion 3, this parameter defaults to `28`.
"""

num_steps = [10, 28, 50]
generated_images = []
for n in num_steps:
    st = time.time()
    generated_images.append(text_to_image.generate(prompt, num_steps=n))
    print(f"Time taken (`num_steps={n}`): {time.time() - st:.2f}s")

display_generated_images(generated_images)

"""
We can use `"negative_prompts"` to guide the model away from generating specific
styles and elements. The input format becomes a dict with the keys `"prompts"`
and `"negative_prompts"`.

If `"negative_prompts"` is not provided, it defaults to the empty string `""`,
which acts as an unconditioned prompt.
"""

generated_images = text_to_image.generate(
    {
        "prompts": [prompt] * 3,
        "negative_prompts": ["Green color"] * 3,
    }
)
display_generated_images(generated_images)

"""
`guidance_scale` affects how much the `"prompts"` influence image generation.
A lower value gives the model creativity to generate images that are more
loosely related to the prompt. Higher values push the model to follow the prompt
more closely. If this value is too high, you may observe some artifacts in the
generated image. In Stable Diffusion 3, it defaults to `7.0`.
"""

generated_images = [
    text_to_image.generate(prompt, guidance_scale=2.5),
    text_to_image.generate(prompt, guidance_scale=7.0),
    text_to_image.generate(prompt, guidance_scale=10.5),
]
display_generated_images(generated_images)

"""
Note that `negative_prompts` and `guidance_scale` are related. The formula in
the implementation can be represented as follows:
`predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)`.

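The formula is easy to check numerically. The sketch below applies it to small NumPy arrays that stand in for the model's two noise predictions. `apply_guidance` and the example values are made up for illustration. A `guidance_scale` of `1.0` returns the positive prediction unchanged, and `0.0` returns the negative one. Larger values extrapolate past the positive prediction, which is consistent with artifacts appearing at very high values.

```python
import numpy as np


def apply_guidance(positive_noise, negative_noise, guidance_scale):
    # Start from the negative (or empty-prompt) prediction and move along the
    # direction toward the positive prediction, scaled by guidance_scale.
    return negative_noise + guidance_scale * (positive_noise - negative_noise)


# Hypothetical noise predictions for the positive and negative prompts.
positive_noise = np.array([1.0, 2.0])
negative_noise = np.array([0.5, 0.5])

for scale in [0.0, 1.0, 7.0]:
    print(scale, apply_guidance(positive_noise, negative_noise, scale).tolist())
# 0.0 [0.5, 0.5]
# 1.0 [1.0, 2.0]
# 7.0 [4.0, 11.0]
```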
"""

"""
## Image-to-image task

A reference image can be used as a starting point for the diffusion process.
This requires an additional module in the pipeline: the encoder from the VAE
model.

The reference image is encoded by the VAE encoder into the latent space, where
noise is then added. The subsequent denoising steps follow the same procedure as
the text-to-image task.

The input format becomes a dict with the keys `"images"`, `"prompts"` and
optionally `"negative_prompts"`.

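Before a reference image is passed in, its pixel values must be mapped from the 8-bit range `[0, 255]` to `[-1.0, 1.0]`. The `keras.layers.Rescaling(scale=1 / 127.5, offset=-1.0)` layer used for this computes `x * scale + offset`, that is, `x / 127.5 - 1.0`. The NumPy sketch below shows this arithmetic on the two endpoints and the midpoint `127.5`. It also shows the inverse, in case you need to turn values in `[-1.0, 1.0]` back into displayable pixels yourself.

```python
import numpy as np

# Forward: pixels in [0, 255] -> [-1.0, 1.0], i.e. x / 127.5 - 1.0, which is
# equivalent to keras.layers.Rescaling(scale=1 / 127.5, offset=-1.0).
pixels = np.array([0, 127.5, 255], dtype=np.float32)
scaled = pixels / 127.5 - 1.0
print(scaled.tolist())  # [-1.0, 0.0, 1.0]

# Inverse: [-1.0, 1.0] -> [0, 255], clipped and rounded to valid uint8 pixels.
restored = np.clip((scaled + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
print(restored.tolist())  # [0, 128, 255]
```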
"""

image_to_image = keras_hub.models.StableDiffusion3ImageToImage(backbone, preprocessor)

image = Image.open("mountain_dog.png").convert("RGB")
image = image.resize((512, 512))
width, height = image.size

# Note that the values of the image must be in the range of [-1.0, 1.0].
rescale = keras.layers.Rescaling(scale=1 / 127.5, offset=-1.0)
image_array = rescale(np.array(image))

prompt = "dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, "
prompt += "adorable, Pixar, Disney, 8k"

generated_image = image_to_image.generate(
    {
        "images": image_array,
        "prompts": prompt,
    }
)
display_generated_images(
    [
        np.array(image),
        generated_image,
    ]
)

"""
As you can see, a new image is generated based on the reference image and the
prompt.
"""

"""
The `strength` parameter plays a key role in determining how closely the
generated image resembles the reference image. Its value lies in the range
`[0.0, 1.0]` and defaults to `0.8` in Stable Diffusion 3.

A higher `strength` value gives the model more "creativity" to generate an image
that is different from the reference image. At a value of `1.0`, the reference
image is completely ignored, making the task purely text-to-image.

A lower `strength` value means the generated image is more similar to the
reference image.

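A common way to implement `strength` in diffusion pipelines is to skip the early part of the denoising schedule. The reference latent is mixed with noise at an intermediate noise level, and only the remaining steps are run. The NumPy sketch below illustrates that idea. It is a toy under stated assumptions, not KerasHub's code:

- `img2img_start` is a hypothetical helper.
- The linear noise schedule from `1.0` (pure noise) to `0.0` (clean) is a simplification; Stable Diffusion 3's actual schedule differs.
- The `int(num_steps * (1 - strength))` starting-step rule is a convention we assume here.

It does reproduce the behavior described above. `strength=1.0` starts from pure noise and ignores the reference, while lower values start closer to the reference and run fewer steps.

```python
import numpy as np


def img2img_start(reference_latent, strength, num_steps=28, seed=0):
    # Assumed convention: skip the first (1 - strength) fraction of the steps.
    start_step = int(num_steps * (1.0 - strength))
    # Assumed linear schedule of noise levels: 1.0 = pure noise, 0.0 = clean.
    noise_levels = np.linspace(1.0, 0.0, num_steps + 1)
    sigma = noise_levels[start_step]
    noise = np.random.default_rng(seed).standard_normal(reference_latent.shape)
    # Interpolate between the clean reference latent and Gaussian noise.
    noisy_latent = (1.0 - sigma) * reference_latent + sigma * noise
    remaining_steps = num_steps - start_step
    return noisy_latent, remaining_steps


reference_latent = np.ones((64, 64, 16))  # hypothetical encoded reference image
for strength in [0.7, 0.8, 0.9, 1.0]:
    _, remaining_steps = img2img_start(reference_latent, strength)
    print(f"strength={strength}: {remaining_steps} denoising steps")
```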
"""

generated_images = [
    image_to_image.generate(
        {
            "images": image_array,
            "prompts": prompt,
        },
        strength=strength,
    )
    for strength in [0.7, 0.8, 0.9]
]
display_generated_images(generated_images)

"""
## Inpaint task

Building upon the image-to-image task, we can also control the generated area
using a mask. This process, in which specific areas of an image are replaced or
edited, is called inpainting.

Inpainting relies on a mask to determine which regions of the image to modify.
The areas to inpaint are represented by white pixels (`True`), while the areas
to preserve are represented by black pixels (`False`).

For inpainting, the input is a dict with the keys `"images"`, `"masks"`,
`"prompts"` and optionally `"negative_prompts"`.

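Here is the mask convention on a tiny 4x4 example. The sketch converts grayscale values to booleans the same way the code below does, with `astype("bool")`, so any nonzero pixel becomes `True`. It then uses `np.where` to show which pixels would take new content and which would be kept. The arrays are made up. The compositing step only illustrates the mask semantics; it is not how KerasHub applies the mask internally.

```python
import numpy as np

# Hypothetical 4x4 grayscale mask: 255 (white) = inpaint, 0 (black) = preserve.
mask_pixels = np.array(
    [
        [0, 0, 0, 0],
        [0, 255, 255, 0],
        [0, 255, 255, 0],
        [0, 0, 0, 0],
    ],
    dtype=np.uint8,
)
mask = mask_pixels.astype("bool")  # nonzero -> True (inpaint), 0 -> False (keep)

original = np.zeros((4, 4, 3), dtype=np.uint8)  # stand-in for the source image
new_content = np.full((4, 4, 3), 200, dtype=np.uint8)  # stand-in for generated pixels

# Take new content where the mask is True and keep the original where it is False.
composite = np.where(mask[..., None], new_content, original)
print(int(mask.sum()), composite[1, 1].tolist(), composite[0, 0].tolist())
# 4 [200, 200, 200] [0, 0, 0]
```

Because any nonzero value maps to `True`, gray anti-aliased edges that resizing can introduce into the mask also count as regions to inpaint. Thresholding, for example `np.array(mask) > 127`, is an alternative if you want a sharper boundary.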
"""

inpaint = keras_hub.models.StableDiffusion3Inpaint(backbone, preprocessor)

image = Image.open("mountain_dog.png").convert("RGB")
image = image.resize((512, 512))
image_array = rescale(np.array(image))

# Note that the mask values are of boolean dtype.
mask = Image.open("mountain_dog_mask.png").convert("L")
mask = mask.resize((512, 512))
mask_array = np.array(mask).astype("bool")

prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly "
prompt += "detailed, 8k"

generated_image = inpaint.generate(
    {
        "images": image_array,
        "masks": mask_array,
        "prompts": prompt,
    }
)
display_generated_images(
    [
        np.array(image),
        np.array(mask.convert("RGB")),
        generated_image,
    ]
)

"""
Fantastic! The dog is replaced by a cute black cat, but unlike image-to-image,
the background is preserved.

Note that the inpaint task also accepts a `strength` parameter to control image
generation, with a default value of `0.6` in Stable Diffusion 3.
"""

"""
## Conclusion

KerasHub's `StableDiffusion3` supports a variety of applications and, with the
help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch!
"""