Path: blob/master/guides/keras_hub/stable_diffusion_3_in_keras_hub.py
3293 views
"""1Title: Stable Diffusion 3 in KerasHub!2Author: [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta)3Date created: 2024/10/094Last modified: 2024/10/245Description: Image generation using KerasHub's Stable Diffusion 3 model.6Accelerator: GPU7"""89"""10## Overview1112Stable Diffusion 3 is a powerful, open-source latent diffusion model (LDM)13designed to generate high-quality novel images based on text prompts. Released14by [Stability AI](https://stability.ai/), it was pre-trained on 1 billion15images and fine-tuned on 33 million high-quality aesthetic and preference images16, resulting in a greatly improved performance compared to previous version of17Stable Diffusion models.1819In this guide, we will explore KerasHub's implementation of the20[Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium)21including text-to-image, image-to-image and inpaint tasks.2223To get started, let's install a few dependencies and get images for our demo:24"""2526"""shell27!pip install -Uq keras28!pip install -Uq git+https://github.com/keras-team/keras-hub.git29!wget -O mountain_dog.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog.png30!wget -O mountain_dog_mask.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog_mask.png31"""3233import os3435os.environ["KERAS_BACKEND"] = "jax"3637import time3839import keras40import keras_hub41import matplotlib.pyplot as plt42import numpy as np43from PIL import Image4445"""46## Introduction4748Before diving into how latent diffusion models work, let's start by generating49some images using KerasHub's APIs.5051To avoid reinitializing variables for different tasks, we'll instantiate and52load the trained `backbone` and `preprocessor` using KerasHub's `from_preset`53factory method. If you only want to perform one task at a time, you can use a54simpler API like this:5556```python57text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(58"stable_diffusion_3_medium", dtype="float16"59)60```6162That will automatically load and configure trained `backbone` and `preprocessor`63for you.6465Note that in this guide, we'll use `image_shape=(512, 512, 3)` for faster66image generation. For higher-quality output, it's recommended to use the default67size of `1024`. Since the entire backbone has about 3 billion parameters, which68can be challenging to fit into a consumer-level GPU, we set `dtype="float16"` to69reduce the usage of GPU memory -- the officially released weights are also in70float16.7172It is also worth noting that the preset "stable_diffusion_3_medium" excludes the73T5XXL text encoder, as it requires significantly more GPU memory. The performace74degradation is negligible in most cases. The weights, including T5XXL, will be75available on KerasHub soon.76"""777879def display_generated_images(images):80"""Helper function to display the images from the inputs.8182This function accepts the following input formats:83- 3D numpy array.84- 4D numpy array: concatenated horizontally.85- List of 3D numpy arrays: concatenated horizontally.86"""87display_image = None88if isinstance(images, np.ndarray):89if images.ndim == 3:90display_image = Image.fromarray(images)91elif images.ndim == 4:92concated_images = np.concatenate(list(images), axis=1)93display_image = Image.fromarray(concated_images)94elif isinstance(images, list):95concated_images = np.concatenate(images, axis=1)96display_image = Image.fromarray(concated_images)9798if display_image is None:99raise ValueError("Unsupported input format.")100101plt.figure(figsize=(10, 10))102plt.axis("off")103plt.imshow(display_image)104plt.show()105plt.close()106107108backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(109"stable_diffusion_3_medium", image_shape=(512, 512, 3), dtype="float16"110)111preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(112"stable_diffusion_3_medium"113)114text_to_image = keras_hub.models.StableDiffusion3TextToImage(backbone, preprocessor)115116"""117Next, we give it a prompt:118"""119120prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"121122# When using JAX or TensorFlow backends, you might experience a significant123# compilation time during the first `generate()` call. The subsequent124# `generate()` call speedup highlights the power of JIT compilation and caching125# in frameworks like JAX and TensorFlow, making them well-suited for126# high-performance deep learning tasks like image generation.127generated_image = text_to_image.generate(prompt)128display_generated_images(generated_image)129130131"""132Pretty impressive! But how does this work?133134Let's dig into what "latent diffusion model" means.135136Consider the concept of "super-resolution," where a deep learning model137"denoises" an input image, turning it into a higher-resolution version. The138model uses its training data distribution to hallucinate the visual details that139are most likely given the input. To learn more about super-resolution, you can140check out the following Keras.io tutorials:141142- [Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/)143- [Enhanced Deep Residual Networks for single-image super-resolution](https://keras.io/examples/vision/edsr/)144145146147When we push this idea to the limit, we may start asking -- what if we just run148such a model on pure noise? The model would then "denoise the noise" and start149hallucinating a brand new image. By repeating the process multiple times, we150can get turn a small patch of noise into an increasingly clear and151high-resolution artificial picture.152153This is the key idea of latent diffusion, proposed in154[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).155To understand diffusion in depth, you can check the Keras.io tutorial156[Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/).157158159160To transition from latent diffusion to a text-to-image system, one key feature161must be added: the ability to control the generated visual content using prompt162keywords. In Stable Diffusion 3, the text encoders from the CLIP and T5XXL163models are used to obtain text embeddings, which are then fed into the diffusion164model to condition the diffusion process. This approach is based on the concept165of "classifier-free guidance", proposed in166[Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).167168When we combine these ideas, we get a high-level overview of the architecture of169Stable Diffusion 3:170171- Text encoders: Convert the text prompt into text embeddings.172- Diffusion model: Repeatedly "denoises" a smaller latent image patch.173- Decoder: Transforms the final latent patch into a higher-resolution image.174175First, the text prompt is projected into the latent space by multiple text176encoders, which are pretrained and frozen language models. Next, the text177embeddings, along with a randomly generated noise patch (typically from a178Gaussian distribution), are then fed into the diffusion model. The diffusion179model repeatly "denoises" the noise patch over a series of steps (the more180steps, the clearer and more refined the image becomes -- the default value is18128 steps). Finally, the latent patch is passed through the decoder from the VAE182model to render the image in high resolution.183184The overview of the Stable Diffusion 3 architecture:185186187This relatively simple system starts looking like magic once we train on188billions of pictures and their captions. As Feynman said about the universe:189_"It's not complicated, it's just a lot of it!"_190"""191192193"""194## Text-to-image task195196Now we know the basis of the Stable Diffusion 3 and the text-to-image task.197Let's explore further using KerasHub APIs.198199To use KerasHub's APIs for efficient batch processing, we can provide the model200with a list of prompts:201"""202203204generated_images = text_to_image.generate([prompt] * 3)205display_generated_images(generated_images)206207"""208The `num_steps` parameter controls the number of denoising steps used during209image generation. Increasing the number of steps typically leads to higher210quality images at the expense of increased generation time. In211Stable Diffusion 3, this parameter defaults to `28`.212"""213214num_steps = [10, 28, 50]215generated_images = []216for n in num_steps:217st = time.time()218generated_images.append(text_to_image.generate(prompt, num_steps=n))219print(f"Cost time (`num_steps={n}`): {time.time() - st:.2f}s")220221display_generated_images(generated_images)222223"""224We can use `"negative_prompts"` to guide the model away from generating specific225styles and elements. The input format becomes a dict with the keys `"prompts"`226and `"negative_prompts"`.227228If `"negative_prompts"` is not provided, it will be interpreted as an229unconditioned prompt with the default value of `""`.230"""231232generated_images = text_to_image.generate(233{234"prompts": [prompt] * 3,235"negative_prompts": ["Green color"] * 3,236}237)238display_generated_images(generated_images)239240"""241`guidance_scale` affects how much the `"prompts"` influences image generation.242A lower value gives the model creativity to generate images that are more243loosely related to the prompt. Higher values push the model to follow the prompt244more closely. If this value is too high, you may observe some artifacts in the245generated image. In Stable Diffusion 3, it defaults to `7.0`.246"""247248generated_images = [249text_to_image.generate(prompt, guidance_scale=2.5),250text_to_image.generate(prompt, guidance_scale=7.0),251text_to_image.generate(prompt, guidance_scale=10.5),252]253display_generated_images(generated_images)254255"""256Note that `negative_prompts` and `guidance_scale` are related. The formula in257the implementation can be represented as follows:258`predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)`.259"""260261"""262## Image-to-image task263264A reference image can be used as a starting point for the diffusion process.265This requires an additional module in the pipeline: the encoder from the VAE266model.267268The reference image is encoded by the VAE encoder into the latent space, where269noise is then added. The subsequent denoising steps follow the same procedure as270the text-to-image task.271272The input format becomes a dict with the keys `"images"`, `"prompts"` and273optionally `"negative_prompts"`.274"""275276image_to_image = keras_hub.models.StableDiffusion3ImageToImage(backbone, preprocessor)277278image = Image.open("mountain_dog.png").convert("RGB")279image = image.resize((512, 512))280width, height = image.size281282# Note that the values of the image must be in the range of [-1.0, 1.0].283rescale = keras.layers.Rescaling(scale=1 / 127.5, offset=-1.0)284image_array = rescale(np.array(image))285286prompt = "dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, "287prompt += "adorable, Pixar, Disney, 8k"288289generated_image = image_to_image.generate(290{291"images": image_array,292"prompts": prompt,293}294)295display_generated_images(296[297np.array(image),298generated_image,299]300)301302"""303As you can see, a new image is generated based on the reference image and the304prompt.305"""306307"""308The `strength` parameter plays a key role in determining how closely the309generated image resembles the reference image. The value ranges from310`[0.0, 1.0]` and defaults to `0.8` in Stable Diffusion 3.311312A higher `strength` value gives the model more “creativity” to generate an image313that is different from the reference image. At a value of `1.0`, the reference314image is completely ignored, making the task purely text-to-image.315316A lower `strength` value means the generated image is more similar to the317reference image.318"""319320generated_images = [321image_to_image.generate(322{323"images": image_array,324"prompts": prompt,325},326strength=0.7,327),328image_to_image.generate(329{330"images": image_array,331"prompts": prompt,332},333strength=0.8,334),335image_to_image.generate(336{337"images": image_array,338"prompts": prompt,339},340strength=0.9,341),342]343display_generated_images(generated_images)344345"""346## Inpaint task347348Building upon the image-to-image task, we can also control the generated area349using a mask. This process is called inpainting, where specific areas of an350image are replaced or edited.351352Inpainting relies on a mask to determine which regions of the image to modify.353The areas to inpaint are represented by white pixels (`True`), while the areas354to preserve are represented by black pixels (`False`).355356For inpainting, the input is a dict with the keys `"images"`, `"masks"`,357`"prompts"` and optionally `"negative_prompts"`.358"""359360inpaint = keras_hub.models.StableDiffusion3Inpaint(backbone, preprocessor)361362image = Image.open("mountain_dog.png").convert("RGB")363image = image.resize((512, 512))364image_array = rescale(np.array(image))365366# Note that the mask values are of boolean dtype.367mask = Image.open("mountain_dog_mask.png").convert("L")368mask = mask.resize((512, 512))369mask_array = np.array(mask).astype("bool")370371prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly "372prompt += "detailed, 8k"373374generated_image = inpaint.generate(375{376"images": image_array,377"masks": mask_array,378"prompts": prompt,379}380)381display_generated_images(382[383np.array(image),384np.array(mask.convert("RGB")),385generated_image,386]387)388389"""390Fantastic! The dog is replaced by a cute black cat, but unlike image-to-image,391the background is preserved.392393Note that inpainting task also includes `strength` parameter to control the394image generation, with the default value of `0.6` in Stable Diffusion 3.395"""396397"""398## Conclusion399400KerasHub's `StableDiffusion3` supports a variety of applications and, with the401help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch!402"""403404405