GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/dreambooth/train_dreambooth.py
import argparse
import hashlib
import itertools
import random
import json
import logging
import math
import os
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


torch.backends.cudnn.benchmark = True


logger = get_logger(__name__)


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with an identifier specifying the instance.",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--save_sample_prompt",
        type=str,
        default=None,
        help="The prompt used to generate sample outputs to save.",
    )
    parser.add_argument(
        "--save_sample_negative_prompt",
        type=str,
        default=None,
        help="The negative prompt used to generate sample outputs to save.",
    )
    parser.add_argument(
        "--n_save_sample",
        type=int,
        default=4,
        help="The number of samples to save.",
    )
    parser.add_argument(
        "--save_guidance_scale",
        type=float,
        default=7.5,
        help="CFG for save sample.",
    )
    parser.add_argument(
        "--save_infer_steps",
        type=int,
        default=20,
        help="The number of inference steps for save sample.",
    )
    parser.add_argument(
        "--pad_tokens",
        default=False,
        action="store_true",
        help="Flag to pad tokens to length 77.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images, additional images"
            " will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images. All the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
    parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
    parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
    parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--concepts_list",
        type=str,
        default=None,
        help="Path to a JSON file containing multiple concepts; it overrides parameters like instance_prompt, class_prompt, etc.",
    )
    parser.add_argument(
        "--read_prompts_from_txts",
        action="store_true",
        help="Use a prompt per image. Put prompts in the same directory as images, e.g. for image.png create image.png.txt.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
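

# A hedged example invocation (illustrative only; the model id, prompts, and paths
# below are placeholders, not defaults shipped with this script):
#
#   accelerate launch train_dreambooth.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --instance_data_dir="./data/my_subject" \
#     --class_data_dir="./data/person_class" \
#     --instance_prompt="a photo of sks person" \
#     --class_prompt="a photo of a person" \
#     --with_prior_preservation --prior_loss_weight=1.0 \
#     --num_class_images=50 \
#     --resolution=512 --train_batch_size=1 \
#     --learning_rate=5e-6 --lr_scheduler="constant" --lr_warmup_steps=0 \
#     --max_train_steps=800 \
#     --output_dir="./dreambooth-model"
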

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        concepts_list,
        tokenizer,
        with_prior_preservation=True,
        size=512,
        center_crop=False,
        num_class_images=None,
        pad_tokens=False,
        hflip=False,
        read_prompts_from_txts=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.with_prior_preservation = with_prior_preservation
        self.pad_tokens = pad_tokens
        self.read_prompts_from_txts = read_prompts_from_txts

        self.instance_images_path = []
        self.class_images_path = []

        for concept in concepts_list:
            inst_img_path = [
                (x, concept["instance_prompt"])
                for x in Path(concept["instance_data_dir"]).iterdir()
                if x.is_file() and not str(x).endswith(".txt")
            ]
            self.instance_images_path.extend(inst_img_path)

            if with_prior_preservation:
                class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
                self.class_images_path.extend(class_img_path[:num_class_images])

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)

        self.image_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(0.5 * hflip),
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]

        if self.read_prompts_from_txts:
            with open(str(instance_path) + ".txt") as f:
                instance_prompt = f.read().strip()

        instance_image = Image.open(instance_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")

        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            padding="max_length" if self.pad_tokens else "do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.with_prior_preservation:
            class_path, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_path)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="max_length" if self.pad_tokens else "do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example

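
# Note on indexing in DreamBoothDataset: _length is max(num_instance_images,
# num_class_images), and each list is indexed modulo its own length in __getitem__,
# so the smaller set simply cycles; every image in both sets is still visited over
# the course of an epoch.
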

class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


class LatentsDataset(Dataset):
    def __init__(self, latents_cache, text_encoder_cache):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index]


class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main(args):
    logging_dir = Path(args.output_dir, "0", args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        project_dir=logging_dir,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate.
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
                "class_data_dir": args.class_data_dir,
            }
        ]
    else:
        with open(args.concepts_list, "r") as f:
            args.concepts_list = json.load(f)

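    # For reference, a sketch of the JSON layout that --concepts_list expects, using
    # the same keys as the fallback above (the prompts and paths are placeholders):
    #
    # [
    #   {
    #     "instance_prompt": "a photo of sks person",
    #     "class_prompt": "a photo of a person",
    #     "instance_data_dir": "./data/my_subject",
    #     "class_data_dir": "./data/person_class"
    #   }
    # ]
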
    if args.with_prior_preservation:
        pipeline = None
        for concept in args.concepts_list:
            class_images_dir = Path(concept["class_data_dir"])
            class_images_dir.mkdir(parents=True, exist_ok=True)
            cur_class_images = len(list(class_images_dir.iterdir()))

            if cur_class_images < args.num_class_images:
                torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
                if pipeline is None:
                    pipeline = StableDiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        vae=AutoencoderKL.from_pretrained(
                            args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                            subfolder=None if args.pretrained_vae_name_or_path else "vae",
                            revision=None if args.pretrained_vae_name_or_path else args.revision,
                            torch_dtype=torch_dtype,
                        ),
                        torch_dtype=torch_dtype,
                        safety_checker=None,
                        revision=args.revision,
                    )
                    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
                    if is_xformers_available():
                        pipeline.enable_xformers_memory_efficient_attention()
                    pipeline.set_progress_bar_config(disable=True)
                    pipeline.to(accelerator.device)

                num_new_images = args.num_class_images - cur_class_images
                logger.info(f"Number of class images to sample: {num_new_images}.")

                sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
                sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

                sample_dataloader = accelerator.prepare(sample_dataloader)

                with torch.autocast("cuda"), torch.inference_mode():
                    for example in tqdm(
                        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
                    ):
                        images = pipeline(
                            example["prompt"],
                            num_inference_steps=args.save_infer_steps,
                        ).images

                        for i, image in enumerate(images):
                            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                            image.save(image_filename)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        torch_dtype=torch.float32,
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if is_xformers_available():
        vae.enable_xformers_memory_efficient_attention()
        unet.enable_xformers_memory_efficient_attention()
    else:
        logger.warning("xformers is not available. Make sure it is installed correctly.")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )
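    # Worked example of the scaling above (illustrative numbers): a base learning rate
    # of 5e-6 with gradient_accumulation_steps=1, train_batch_size=4 and 2 processes
    # becomes 5e-6 * 1 * 4 * 2 = 4e-5.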

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        concepts_list=args.concepts_list,
        tokenizer=tokenizer,
        with_prior_preservation=args.with_prior_preservation,
        size=args.resolution,
        center_crop=args.center_crop,
        num_class_images=args.num_class_images,
        pad_tokens=args.pad_tokens,
        hflip=args.hflip,
        read_prompts_from_txts=args.read_prompts_from_txts,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision;
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if not args.not_cache_latents:
        latents_cache = []
        text_encoder_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
                if args.train_text_encoder:
                    text_encoder_cache.append(batch["input_ids"])
                else:
                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)

        del vae
        if not args.train_text_encoder:
            del text_encoder
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
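
        # Note on the cache above: encoding every image once trades a single upfront
        # pass (plus the memory to hold all latent distributions) for skipping the VAE,
        # and for a frozen text encoder the text encoder too, on every training step.
        # Because latents are cached once, the random crop/flip augmentations from
        # DreamBoothDataset are baked in at cache time rather than re-rolled each epoch.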

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    def save_weights(step):
        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            if args.train_text_encoder:
                text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
            else:
                text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
                text_encoder=text_enc_model,
                vae=AutoencoderKL.from_pretrained(
                    args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                    subfolder=None if args.pretrained_vae_name_or_path else "vae",
                    revision=None if args.pretrained_vae_name_or_path else args.revision,
                ),
                safety_checker=None,
                torch_dtype=torch.float16,
                revision=args.revision,
            )
            pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
            if is_xformers_available():
                pipeline.enable_xformers_memory_efficient_attention()
            save_dir = os.path.join(args.output_dir, f"{step}")
            pipeline.save_pretrained(save_dir)
            with open(os.path.join(save_dir, "args.json"), "w") as f:
                json.dump(args.__dict__, f, indent=2)

            if args.save_sample_prompt is not None:
                pipeline = pipeline.to(accelerator.device)
                g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                pipeline.set_progress_bar_config(disable=True)
                sample_dir = os.path.join(save_dir, "samples")
                os.makedirs(sample_dir, exist_ok=True)
                with torch.autocast("cuda"), torch.inference_mode():
                    for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
                        images = pipeline(
                            args.save_sample_prompt,
                            negative_prompt=args.save_sample_negative_prompt,
                            guidance_scale=args.save_guidance_scale,
                            num_inference_steps=args.save_infer_steps,
                            generator=g_cuda,
                        ).images
                        images[0].save(os.path.join(sample_dir, f"{i}.png"))
                del pipeline
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            print(f"[*] Weights saved at {save_dir}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    loss_avg = AverageMeter()
    text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
                    if not args.not_cache_latents:
                        latent_dist = batch[0][0]
                    else:
                        latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
                    latents = latent_dist.sample() * 0.18215
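                    # 0.18215 is the Stable Diffusion VAE latent scaling factor
                    # (vae.config.scaling_factor in newer diffusers); it rescales the
                    # latents to roughly unit variance before diffusion.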

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
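                # For reference, add_noise implements the standard DDPM forward process:
                # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
                # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.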

                # Get the text embedding for conditioning
                with text_enc_context:
                    if not args.not_cache_latents:
                        if args.train_text_encoder:
                            encoder_hidden_states = text_encoder(batch[0][1])[0]
                        else:
                            encoder_hidden_states = batch[0][1]
                    else:
                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
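                # For v-prediction models, get_velocity returns the target
                # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
                # (Salimans & Ho, 2022, "Progressive Distillation").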

                if args.with_prior_preservation:
                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)
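                    # This split relies on the batch layout produced by collate_fn:
                    # the first half of the batch holds instance examples and the
                    # second half holds the corresponding class (prior) examples.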

                    # Compute instance loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute prior loss
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                # if accelerator.sync_gradients:
                #     params_to_clip = (
                #         itertools.chain(unet.parameters(), text_encoder.parameters())
                #         if args.train_text_encoder
                #         else unet.parameters()
                #     )
                #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                loss_avg.update(loss.detach_(), bsz)

            if not global_step % args.log_interval:
                logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
                save_weights(global_step)

            progress_bar.update(1)
            global_step += 1

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    save_weights(global_step)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
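
# A hedged usage sketch (not part of the original script): each checkpoint directory
# written by save_weights is a complete diffusers pipeline, so it can be reloaded for
# inference roughly like this ("./dreambooth-model/800" is a placeholder path):
#
#   import torch
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "./dreambooth-model/800", torch_dtype=torch.float16
#   ).to("cuda")
#   image = pipe("a photo of sks person on the beach").images[0]
#   image.save("sample.png")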