GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/dreambooth/train_inpainting_dreambooth.py
import argparse
import hashlib
import itertools
import json
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
                       StableDiffusionInpaintPipeline, UNet2DConditionModel)
from diffusers.optimization import get_scheduler

torch.backends.cudnn.benchmark = True


logger = get_logger(__name__)

def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="fp16",
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    # parser.add_argument(
    #     "--save_sample_prompt",
    #     type=str,
    #     default=None,
    #     help="The prompt used to generate sample outputs to save.",
    # )
    parser.add_argument(
        "--save_sample_negative_prompt",
        type=str,
        default=None,
        help="The negative prompt used to generate sample outputs to save.",
    )
    parser.add_argument(
        "--n_save_sample",
        type=int,
        default=4,
        help="The number of samples to save.",
    )
    parser.add_argument(
        "--save_guidance_scale",
        type=float,
        default=7.5,
        help="Classifier-free guidance (CFG) scale used when generating saved samples.",
    )
    parser.add_argument(
        "--save_infer_steps",
        type=int,
        default=50,
        help="The number of inference steps used when generating saved samples.",
    )
    parser.add_argument(
        "--pad_tokens",
        default=False,
        action="store_true",
        help="Flag to pad tokens to length 77.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal number of class images for prior preservation loss. If there are not enough images already"
            " present in class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
212
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
213
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
214
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
215
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
216
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
217
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
218
parser.add_argument(
219
"--hub_model_id",
220
type=str,
221
default=None,
222
help="The name of the repository to keep in sync with the local `output_dir`.",
223
)
224
parser.add_argument(
225
"--logging_dir",
226
type=str,
227
default="logs",
228
help=(
229
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
230
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
231
),
232
)
233
parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
234
parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
235
parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
236
parser.add_argument(
237
"--mixed_precision",
238
type=str,
239
default="no",
240
choices=["no", "fp16", "bf16"],
241
help=(
242
"Whether to use mixed precision. Choose"
243
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
244
"and an Nvidia Ampere GPU."
245
),
246
)
247
parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
248
parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
249
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
250
parser.add_argument(
251
"--concepts_list",
252
type=str,
253
default=None,
254
help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
255
)
256
257
if input_args is not None:
258
args = parser.parse_args(input_args)
259
else:
260
args = parser.parse_args()
261
262
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
263
if env_local_rank != -1 and env_local_rank != args.local_rank:
264
args.local_rank = env_local_rank
265
266
return args
267
268
269
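
# Random mask generation for inpainting training: each training image gets a mask built from a
# number of random rectangular "holes" (the regions the model must inpaint). About 25% of the time
# the whole image is masked, so the model also learns to fill a fully masked image from the prompt.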
def get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128):
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def generate_random_mask(image):
    mask = torch.zeros_like(image[:1])
    holes = get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.)
    masked_image = image * (mask < 0.5)
    return mask, masked_image

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        concepts_list,
        tokenizer,
        with_prior_preservation=True,
        size=512,
        center_crop=False,
        num_class_images=None,
        pad_tokens=False,
        hflip=False
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.with_prior_preservation = with_prior_preservation
        self.pad_tokens = pad_tokens

        self.instance_images_path = []
        self.class_images_path = []

        for concept in concepts_list:
            inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
            self.instance_images_path.extend(inst_img_path)

            if with_prior_preservation:
                class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
                self.class_images_path.extend(class_img_path[:num_class_images])

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)

        self.image_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(0.5 * hflip),
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
        instance_image = Image.open(instance_path)
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_masks"], example["instance_masked_images"] = generate_random_mask(example["instance_images"])
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            padding="max_length" if self.pad_tokens else "do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.with_prior_preservation:
            class_path, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_path)
            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_masks"], example["class_masked_images"] = generate_random_mask(example["class_images"])
            example["class_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="max_length" if self.pad_tokens else "do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example

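
# Small helpers: PromptDataset feeds the class prompt to the sampling dataloader when generating
# class images, LatentsDataset wraps the precomputed VAE latents / text-encoder outputs used when
# latent caching is enabled, and AverageMeter keeps a running average of the training loss.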
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


class LatentsDataset(Dataset):
    def __init__(self, latents_cache, text_encoder_cache):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index]


class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

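
# main() flow: configure Accelerate, optionally generate class images for prior preservation,
# load the tokenizer, text encoder, VAE and inpainting UNet, build the DreamBooth dataset and
# dataloader, optionally cache VAE latents and text embeddings, then run the training loop and
# periodically save full inpainting pipelines together with sample renders.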
def main(args):
    logging_dir = Path(args.output_dir, "0", args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
                "class_data_dir": args.class_data_dir
            }
        ]
    else:
        with open(args.concepts_list, "r") as f:
            args.concepts_list = json.load(f)

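    # Prior preservation: if class_data_dir does not yet hold num_class_images images, generate the
    # missing ones with the base inpainting pipeline, using a fully white mask over a black canvas so
    # the result is effectively an unmasked text-to-image sample of the class prompt.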
    if args.with_prior_preservation:
        pipeline = None
        for concept in args.concepts_list:
            class_images_dir = Path(concept["class_data_dir"])
            class_images_dir.mkdir(parents=True, exist_ok=True)
            cur_class_images = len(list(class_images_dir.iterdir()))

            if cur_class_images < args.num_class_images:
                torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
                if pipeline is None:
                    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        vae=AutoencoderKL.from_pretrained(
                            args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                            subfolder=None if args.pretrained_vae_name_or_path else "vae",
                            revision=None if args.pretrained_vae_name_or_path else args.revision,
                            torch_dtype=torch_dtype
                        ),
                        torch_dtype=torch_dtype,
                        safety_checker=None,
                        revision=args.revision
                    )
                    pipeline.set_progress_bar_config(disable=True)
                    pipeline.to(accelerator.device)

                num_new_images = args.num_class_images - cur_class_images
                logger.info(f"Number of class images to sample: {num_new_images}.")

                sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
                sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

                sample_dataloader = accelerator.prepare(sample_dataloader)

                inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
                inp_mask = Image.new("L", (512, 512), color=255)

                with torch.autocast("cuda"), torch.inference_mode():
                    for example in tqdm(
                        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
                    ):
                        images = pipeline(
                            prompt=example["prompt"],
                            image=inp_img,
                            mask_image=inp_mask,
                            num_inference_steps=args.save_infer_steps
                        ).images

                        for i, image in enumerate(images):
                            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                            image.save(image_filename)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        torch_dtype=torch.float32
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        concepts_list=args.concepts_list,
        tokenizer=tokenizer,
        with_prior_preservation=args.with_prior_preservation,
        size=args.resolution,
        center_crop=args.center_crop,
        num_class_images=args.num_class_images,
        pad_tokens=args.pad_tokens,
        hflip=args.hflip
    )

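    # collate_fn stacks the image, mask and masked-image tensors and pads the tokenized prompts to a
    # common length. With prior preservation the class examples are appended after the instance
    # examples so a single forward pass covers both halves of the loss.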
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [example["instance_masked_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [example["class_masked_images"] for example in examples]

        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=8
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to GPU.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

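    # Latent caching: when enabled (the default), every batch is pushed through the frozen VAE (and,
    # unless the text encoder is being trained, through the text encoder) exactly once up front, so
    # both models can be dropped from GPU memory for the rest of training.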
    if not args.not_cache_latents:
        latents_cache = []
        text_encoder_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
                batch["masked_image_values"] = batch["masked_image_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
                batch["mask_values"] = batch["mask_values"].to(accelerator.device, non_blocking=True)
                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
                # Cache everything the inpainting training step needs from the VAE: the latent
                # distributions of the images and of the masked images, plus the mask downsampled
                # to latent resolution.
                latents_cache.append((
                    vae.encode(batch["pixel_values"]).latent_dist,
                    vae.encode(batch["masked_image_values"]).latent_dist,
                    F.interpolate(batch["mask_values"], scale_factor=1 / 8),
                ))
                if args.train_text_encoder:
                    text_encoder_cache.append(batch["input_ids"])
                else:
                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)

        del vae
        if not args.train_text_encoder:
            del text_encoder
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

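    # save_weights() rebuilds a full StableDiffusionInpaintPipeline in fp16 from the (unwrapped)
    # trained UNet and text encoder, saves it together with the training arguments under
    # output_dir/<step>, and renders a few sample inpaintings per concept (blank image plus full
    # mask) as a quick visual check.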
    def save_weights(step):
        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            if args.train_text_encoder:
                text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
            else:
                text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).to(torch.float16),
                text_encoder=text_enc_model.to(torch.float16),
                vae=AutoencoderKL.from_pretrained(
                    args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                    subfolder=None if args.pretrained_vae_name_or_path else "vae",
                    revision=None if args.pretrained_vae_name_or_path else args.revision,
                ),
                safety_checker=None,
                scheduler=scheduler,
                torch_dtype=torch.float16,
                revision=args.revision,
            )
            save_dir = os.path.join(args.output_dir, f"{step}")
            pipeline.save_pretrained(save_dir)
            with open(os.path.join(save_dir, "args.json"), "w") as f:
                json.dump(args.__dict__, f, indent=2)

            shutil.copy("train_inpainting_dreambooth.py", save_dir)

            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)
            for idx, concept in enumerate(args.concepts_list):
                g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed if args.seed is not None else 0)
                sample_dir = os.path.join(save_dir, "samples", str(idx))
                os.makedirs(sample_dir, exist_ok=True)
                inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
                inp_mask = Image.new("L", (512, 512), color=255)
                with torch.inference_mode():
                    for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
                        images = pipeline(
                            prompt=concept["instance_prompt"],
                            image=inp_img,
                            mask_image=inp_mask,
                            negative_prompt=args.save_sample_negative_prompt,
                            guidance_scale=args.save_guidance_scale,
                            num_inference_steps=args.save_infer_steps,
                            generator=g_cuda
                        ).images
                        images[0].save(os.path.join(sample_dir, f"{i}.png"))
            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"[*] Weights saved at {save_dir}")
            unet.to(torch.float32)
            text_enc_model.to(torch.float32)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    loss_avg = AverageMeter()
    text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()

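    # Training loop. Each step builds the 9-channel input the Stable Diffusion inpainting UNet
    # expects (4 noisy latent channels, 1 downsampled mask channel and 4 masked-image latent
    # channels) and trains the UNet to predict the added noise. With prior preservation the batch
    # contains instance and class examples back to back, and their losses are combined using
    # prior_loss_weight.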
    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        if isinstance(train_dataset, DreamBoothDataset):
            random.shuffle(train_dataset.class_images_path)
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
                    if not args.not_cache_latents:
                        latent_dist, masked_latent_dist, mask = batch[0][0]
                    else:
                        latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
                        masked_latent_dist = vae.encode(batch["masked_image_values"].to(dtype=weight_dtype)).latent_dist
                        mask = F.interpolate(batch["mask_values"], scale_factor=1 / 8)
                    # Scale by the VAE latent scaling factor used by Stable Diffusion.
                    latents = latent_dist.sample() * 0.18215
                    masked_image_latents = masked_latent_dist.sample() * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                with text_enc_context:
                    if not args.not_cache_latents:
                        if args.train_text_encoder:
                            encoder_hidden_states = text_encoder(batch[0][1])[0]
                        else:
                            encoder_hidden_states = batch[0][1]
                    else:
                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                # if accelerator.sync_gradients:
                #     params_to_clip = (
                #         itertools.chain(unet.parameters(), text_encoder.parameters())
                #         if args.train_text_encoder
                #         else unet.parameters()
                #     )
                #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                loss_avg.update(loss.detach_(), bsz)

            if not global_step % args.log_interval:
                logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
                save_weights(global_step)

            progress_bar.update(1)
            global_step += 1

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    save_weights(global_step)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)