
Fine-tuning Stable Diffusion XL with DreamBooth and LoRA on a free-tier Colab Notebook 🧨

In this notebook, we show how to fine-tune Stable Diffusion XL (SDXL) with DreamBooth and LoRA on a T4 GPU.

SDXL has a much larger UNet and two text encoders, which together make the cross-attention context considerably larger than in previous variants.

To pull this off, we'll use several tricks: gradient checkpointing, mixed-precision training, and the 8-bit Adam optimizer. Hang tight and let's get started 🧪

Setup 🪓

# Check the GPU
!nvidia-smi
Thu Nov 23 06:47:16 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
# Install dependencies.
!pip install bitsandbytes transformers accelerate peft -q

Make sure to install diffusers from main.

!pip install git+https://github.com/huggingface/diffusers.git -q
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Building wheel for diffusers (pyproject.toml) ... done

Download the diffusers SDXL DreamBooth training script.

!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
--2023-11-23 06:48:12--  https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 72845 (71K) [text/plain]
Saving to: ‘train_dreambooth_lora_sdxl.py’

train_dreambooth_lo 100%[===================>]  71.14K  --.-KB/s    in 0.001s

2023-11-23 06:48:12 (46.6 MB/s) - ‘train_dreambooth_lora_sdxl.py’ saved [72845/72845]

Dataset 🐶

Let's get our training data! For this example, we'll download some images from the hub.

If you already have a dataset on the hub you wish to use, you can skip this part and go straight to the "Prep for training 💻" section, where you'll simply specify the dataset name.

If your images are saved locally, and/or you want to add BLIP-generated captions, pick option 1 or 2 below.

Option 1: upload example images from your local files:

import os
from google.colab import files

# pick a name for the image folder
local_dir = "./dog/"  # @param
os.makedirs(local_dir)
os.chdir(local_dir)

# choose and upload local images into the newly created directory
uploaded_images = files.upload()
os.chdir("/content")  # back to parent directory

Option 2: download example images from the hub:

from huggingface_hub import snapshot_download

local_dir = "./dog/"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
'/content/dog'

Preview the images:

from PIL import Image

def image_grid(imgs, rows, cols, resize=256):
    if resize is not None:
        imgs = [img.resize((resize, resize)) for img in imgs]
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
import glob

# change path to display images from your local dir
img_paths = "./dog/*.jpeg"
imgs = [Image.open(path) for path in glob.glob(img_paths)]

num_imgs_to_preview = 5
image_grid(imgs[:num_imgs_to_preview], 1, num_imgs_to_preview)
Image in a Jupyter notebook

Generate custom captions with BLIP

Load BLIP to auto-caption your images:

import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the processor and the captioning model
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=torch.float16
).to(device)

# captioning utility
def caption_images(input_image):
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
    pixel_values = inputs.pixel_values
    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
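As a quick sanity check, you can caption a single image before processing the whole folder. This is just an illustrative snippet (it grabs whichever image glob finds first in your local directory):

import glob
from PIL import Image

# caption the first image in the folder to verify BLIP works
sample_path = glob.glob("./dog/*.jpeg")[0]
print(caption_images(Image.open(sample_path)))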
import glob
from PIL import Image

# create a list of (path, PIL.Image) pairs
local_dir = "./dog/"
imgs_and_paths = [(path, Image.open(path)) for path in glob.glob(f"{local_dir}*.jpeg")]

Now let's add the concept token identifier (e.g. TOK) to each caption using a caption prefix. Feel free to change the prefix according to the concept you're training on!

  • For this example we can use "a photo of TOK"; other options include:

    • For styles - "In the style of TOK"

    • For faces - "photo of a TOK person"

  • You can add additional identifiers to the prefix to help steer the model in the right direction; e.g. for this example, instead of "a photo of TOK" we can use "a photo of TOK dog" or "a photo of TOK corgi dog".

import json

caption_prefix = "a photo of TOK dog, "  # @param
with open(f"{local_dir}metadata.jsonl", "w") as outfile:
    for img in imgs_and_paths:
        caption = caption_prefix + caption_images(img[1]).split("\n")[0]
        entry = {"file_name": img[0].split("/")[-1], "prompt": caption}
        json.dump(entry, outfile)
        outfile.write("\n")

Free some memory:

import gc

# delete the BLIP pipelines and free up some memory
del blip_processor, blip_model
gc.collect()
torch.cuda.empty_cache()

Prep for training 💻

Initialize accelerate:

import locale
locale.getpreferredencoding = lambda: "UTF-8"

!accelerate config default
accelerate configuration saved at /root/.cache/huggingface/accelerate/default_config.yaml

Log into your Hugging Face account

Pass your write access token so that we can push the trained checkpoints to the Hugging Face Hub:

from huggingface_hub import notebook_login

notebook_login()

Train! 🔬

Set Hyperparameters ⚡

To ensure we can DreamBooth with LoRA on a heavy pipeline like Stable Diffusion XL, we're using the following (see the sketch after this list for what each flag does under the hood):

  • Gradient checkpointing (--gradient_checkpointing)

  • 8-bit Adam (--use_8bit_adam)

  • Mixed-precision training (--mixed_precision="fp16")
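Roughly, these three flags correspond to the plain PyTorch/diffusers calls below. This is only a minimal sketch for intuition; the training script wires all of this up for us, and the `unet` and optimizer here are illustrative:

import torch
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

# load the SDXL UNet, the largest trainable component
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
).to("cuda")

# --gradient_checkpointing: recompute activations during the backward
# pass instead of storing them all, trading compute for memory
unet.enable_gradient_checkpointing()

# --use_8bit_adam: keep optimizer state in 8 bits instead of 32
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-4)

# --mixed_precision="fp16": run the forward pass in half precision
with torch.autocast("cuda", dtype=torch.float16):
    pass  # forward pass and loss computation would go here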

Launch training 🚀🚀🚀

To allow for custom captions we need to install the datasets library. You can skip this if you want to train solely with --instance_prompt; in that case, specify --instance_data_dir instead of --dataset_name.
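For reference, here's a trimmed sketch of what that datasets-free launch could look like; it reuses flags from the full command below, swapping --dataset_name and --caption_column for --instance_data_dir:

# sketch: train from a local folder without custom captions; every
# image is paired with the single caption given by --instance_prompt
!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of TOK dog" \
  --output_dir="corgy_dog_LoRA" \
  --resolution=1024 \
  --train_batch_size=1 \
  --max_train_steps=500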

!pip install datasets -q
  • Use --output_dir to specify your LoRA model repository name!

  • Use --caption_column to specify the name of the caption column in your dataset. In this example we used "prompt" to save our captions in the metadata file; change this according to your needs.
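For reference, each line of the metadata.jsonl we generated earlier is one JSON object per image, with the caption stored under the "prompt" key (the file name and caption below are illustrative):

{"file_name": "image_1.jpeg", "prompt": "a photo of TOK dog, a dog sitting on the grass"}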

#!/usr/bin/env bash
!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="dog" \
  --output_dir="corgy_dog_LoRA" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of TOK dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=717 \
  --seed="0"
WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python 3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
2023-11-23 07:06:49.633870: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-23 07:06:49.633948: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-23 07:06:49.638631: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-23 07:06:52.427754: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
11/23/2023 07:06:55 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'thresholding', 'dynamic_thresholding_ratio', 'clip_sample_range', 'variance_type'} was not found in config. Values will be initialized to default values.
{'attention_type', 'reverse_transformer_layers_per_block', 'dropout'} was not found in config. Values will be initialized to default values.
11/23/2023 07:08:28 - WARNING - datasets.builder - Found cached dataset imagefolder (/root/.cache/huggingface/datasets/imagefolder/dog-3c3c059549bb4011/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
100% 1/1 [00:00<00:00, 17.73it/s]
11/23/2023 07:08:29 - INFO - __main__ - ***** Running training *****
11/23/2023 07:08:29 - INFO - __main__ -   Num examples = 5
11/23/2023 07:08:29 - INFO - __main__ -   Num batches each epoch = 5
11/23/2023 07:08:29 - INFO - __main__ -   Num Epochs = 250
11/23/2023 07:08:29 - INFO - __main__ -   Instantaneous batch size per device = 1
11/23/2023 07:08:29 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 3
11/23/2023 07:08:29 - INFO - __main__ -   Gradient Accumulation steps = 3
11/23/2023 07:08:29 - INFO - __main__ -   Total optimization steps = 500
Steps: 100% 500/500 [1:08:02<00:00, 7.96s/it, loss=0.00515, lr=0.0001]
Model weights saved in corgy_dog_LoRA/pytorch_lora_weights.safetensors
{'image_encoder', 'feature_extractor'} was not found in config. Values will be initialized to default values.
Loading pipeline components...:   0% 0/7 [00:00<?, ?it/s]
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...:  14% 1/7 [00:00<00:00, 9.49it/s]
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded text_encoder_2 as CLIPTextModelWithProjection from `text_encoder_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...:  43% 3/7 [00:17<00:25, 6.34s/it]
Loaded scheduler as EulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...:  57% 4/7 [00:17<00:12, 4.19s/it]
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 994, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 636, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', 'train_dreambooth_lora_sdxl.py', '--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0', '--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix', '--dataset_name=dog', '--output_dir=corgy_dog_LoRA', '--caption_column=prompt', '--mixed_precision=fp16', '--instance_prompt=a photo of TOK dog', '--resolution=1024', '--train_batch_size=1', '--gradient_accumulation_steps=3', '--gradient_checkpointing', '--learning_rate=1e-4', '--snr_gamma=5.0', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--mixed_precision=fp16', '--use_8bit_adam', '--max_train_steps=500', '--checkpointing_steps=717', '--seed=0', '--push_to_hub']' died with <Signals.SIGKILL: 9>.

Save your model to the hub and check it out 🔥

from huggingface_hub import whoami
from pathlib import Path

# @markdown make sure the `output_dir` you specify here is the same as the one used for training
output_dir = "corgy_dog_LoRA"  # @param
username = whoami(token=Path("/root/.cache/huggingface/"))["name"]
repo_id = f"{username}/{output_dir}"
# @markdown Sometimes training finishes successfully (i.e. a **.safetensors** file with the LoRA weights saved properly to your local `output_dir`) but there's not enough RAM in the free tier to push the model to the hub 🙁
# @markdown
# @markdown To mitigate this, run this cell with your training arguments to make sure your model is uploaded! 🤗

# push to the hub 🔥
from train_dreambooth_lora_sdxl import save_model_card
from huggingface_hub import upload_folder, create_repo

repo_id = create_repo(repo_id, exist_ok=True).repo_id

# change the params below according to your training arguments
save_model_card(
    repo_id=repo_id,
    images=[],
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    train_text_encoder=False,
    instance_prompt="a photo of TOK dog",
    validation_prompt=None,
    repo_folder=output_dir,
    vae_path="madebyollin/sdxl-vae-fp16-fix",
)

upload_folder(
    repo_id=repo_id,
    folder_path=output_dir,
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)
WARNING:xformers:WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python 3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
'https://huggingface.co/LinoyTsaban/corgy_dog_LoRA/tree/main/'
from IPython.display import display, Markdown

link_to_model = f"https://huggingface.co/{repo_id}"
display(Markdown("### Your model has finished training.\nAccess it here: {}".format(link_to_model)))

Your model has finished training.

Access it here: https://huggingface.co/LinoyTsaban/corgy_dog_LoRA

Let's generate some images with it!

Inference 🐕

import torch
from diffusers import DiffusionPipeline, AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe.load_lora_weights(repo_id)
_ = pipe.to("cuda")
prompt = "a photo of TOK dog in a bucket at the beach"  # @param

image = pipe(prompt=prompt, num_inference_steps=25).images[0]
image
Image in a Jupyter notebook
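The pipeline returns standard PIL images, so if you want to keep the result you can save it to disk (the file name here is up to you):

# save the generated image locally
image.save("tok_dog_beach.png")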