GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/dreambooth/train_dreambooth_lora.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import hashlib
import logging
import math
import os
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__)


def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA DreamBooth - {repo_name}

These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="lora-dreambooth-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example


def collate_fn(examples, with_prior_preservation=False):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
                torch_dtype = torch.float16
            elif args.prior_generation_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id

            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
    # => 32 layers

    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)
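    # Note: `AttnProcsLayers` gathers the LoRA attention processors into a single
    # torch.nn.Module, so the optimizer, gradient clipping and checkpoint hooks below
    # only see the small set of LoRA parameters while the frozen UNet weights stay untouched.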

    accelerator.register_for_checkpointing(lora_layers)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth-lora", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute prior loss
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                images = [
                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                    for _ in range(args.num_validation_images)
                ]

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                    if tracker.name == "wandb":
                        tracker.log(
                            {
                                "validation": [
                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                    for i, image in enumerate(images)
                                ]
                            }
                        )

                del pipeline
                torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)
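        # `save_attn_procs` writes only the trained LoRA attention-processor weights to
        # `args.output_dir` (the frozen base UNet is not saved); `load_attn_procs` below
        # loads the same weights back for the final inference check.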

        # Final inference
        # Load previous pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
        )
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        pipeline = pipeline.to(accelerator.device)

        # load attention processors
        pipeline.unet.load_attn_procs(args.output_dir)

        # run inference
        if args.validation_prompt and args.num_validation_images > 0:
            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
            images = [
                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                for _ in range(args.num_validation_images)
            ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test": [
                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                            for i, image in enumerate(images)
                        ]
                    }
                )

        if args.push_to_hub:
            save_model_card(
                repo_name,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                prompt=args.instance_prompt,
                repo_folder=args.output_dir,
            )
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
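
The script stores only the LoRA attention-processor weights in `--output_dir`. As a minimal sketch (not part of the training script itself), those weights can be reused for inference roughly as follows; the base model name, the output directory, and the prompt are placeholders and mirror the final-inference block above:

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

base_model = "runwayml/stable-diffusion-v1-5"  # should match --pretrained_model_name_or_path
lora_dir = "lora-dreambooth-model"             # should match --output_dir

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# Attach the LoRA attention processors written by `unet.save_attn_procs(args.output_dir)`.
pipe.unet.load_attn_procs(lora_dir)

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")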