Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
ShivamShrirao
GitHub Repository: ShivamShrirao/diffusers
Path: blob/main/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
1979 views
1
import argparse
2
import hashlib
3
import itertools
4
import math
5
import os
6
import random
7
from pathlib import Path
8
from typing import Optional
9
10
import numpy as np
11
import torch
12
import torch.nn.functional as F
13
import torch.utils.checkpoint
14
from accelerate import Accelerator
15
from accelerate.logging import get_logger
16
from accelerate.utils import ProjectConfiguration, set_seed
17
from huggingface_hub import HfFolder, Repository, create_repo, whoami
18
from PIL import Image, ImageDraw
19
from torch.utils.data import Dataset
20
from torchvision import transforms
21
from tqdm.auto import tqdm
22
from transformers import CLIPTextModel, CLIPTokenizer
23
24
from diffusers import (
25
AutoencoderKL,
26
DDPMScheduler,
27
StableDiffusionInpaintPipeline,
28
StableDiffusionPipeline,
29
UNet2DConditionModel,
30
)
31
from diffusers.optimization import get_scheduler
32
from diffusers.utils import check_min_version
33
34
35
# Module-wide logger used by the training helpers below.
logger = get_logger(__name__)

# Fail fast when the installed diffusers is older than this script expects.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")


def prepare_mask_and_masked_image(image, mask):
    """Turn a PIL image and a PIL mask into the tensors the inpainting UNet consumes.

    Args:
        image: PIL image; converted to RGB and scaled from [0, 255] to [-1, 1].
        mask: PIL image; converted to grayscale and binarized at 0.5 (white = 1).

    Returns:
        ``(mask, masked_image)`` where ``mask`` is a float32 tensor of shape
        ``(1, 1, H, W)`` holding 0/1 values, and ``masked_image`` is the
        ``(1, 3, H, W)`` image with every pixel under the mask set to 0.
    """
    rgb = np.array(image.convert("RGB"))  # (H, W, 3), uint8
    pixels = torch.from_numpy(np.transpose(rgb[np.newaxis], (0, 3, 1, 2))).to(dtype=torch.float32)
    pixels = pixels / 127.5 - 1.0

    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    # Hard threshold: values >= 0.5 become 1 (inpaint here), everything else 0 (keep).
    binary = np.where(gray[np.newaxis, np.newaxis] >= 0.5, 1.0, 0.0).astype(np.float32)
    mask_tensor = torch.from_numpy(binary)

    # Keep only the pixels outside the mask; the boolean factor broadcasts over channels.
    masked_pixels = pixels * (mask_tensor < 0.5)
    return mask_tensor, masked_pixels


# generate random masks
def random_mask(im_shape, ratio=1, mask_full_image=False):
    """Draw a random white rectangle or ellipse on a black grayscale canvas.

    Args:
        im_shape: canvas size passed to ``Image.new`` (PIL order: width, height).
        ratio: upper bound of the shape's extent as a fraction of the canvas size.
        mask_full_image: when true, use the maximal extent and always draw a
            rectangle (covers the whole canvas for ``ratio=1``).

    Returns:
        A PIL ``"L"`` image where masked pixels are 255 and the rest 0.
    """
    canvas = Image.new("L", im_shape, 0)
    painter = ImageDraw.Draw(canvas)

    max_w = int(im_shape[0] * ratio)
    max_h = int(im_shape[1] * ratio)
    # The random extent is always sampled (even if overridden below) so the
    # sequence of `random` calls does not depend on `mask_full_image`.
    box_w, box_h = random.randint(0, max_w), random.randint(0, max_h)
    if mask_full_image:
        box_w, box_h = max_w, max_h

    half_w, half_h = box_w // 2, box_h // 2
    cx = random.randint(half_w, im_shape[0] - half_w)
    cy = random.randint(half_h, im_shape[1] - half_h)

    use_rectangle = random.randint(0, 1) == 0 or mask_full_image
    draw_shape = painter.rectangle if use_rectangle else painter.ellipse
    draw_shape((cx - half_w, cy - half_h, cx + half_w, cy + half_h), fill=255)

    return canvas


def parse_args(input_args=None):
    """Parse and validate the command-line options for DreamBooth inpainting training.

    Args:
        input_args: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use). ``None``
            keeps the normal command-line behaviour.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is replaced by the
        ``LOCAL_RANK`` environment variable when that is set to something other
        than -1.

    Raises:
        ValueError: if no instance data directory or instance prompt is given,
            or if ``--with_prior_preservation`` is set without
            ``--class_data_dir`` / ``--class_prompt``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            # The original fragments were concatenated without separating spaces
            # ("Choosebetween", "1.10.and"); keep explicit spaces between pieces.
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    # `parse_args(None)` falls back to sys.argv[1:], so CLI behaviour is unchanged.
    args = parser.parse_args(input_args)

    # Launchers such as torchrun export LOCAL_RANK; it wins over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    # Every training example is tokenized from this prompt; fail here rather than
    # deep inside the tokenizer on the first batch.
    if args.instance_prompt is None:
        raise ValueError("You must specify a prompt for instance images.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Every item is a dict containing:

    * ``"PIL_images"`` -- the resized and cropped instance image (PIL), kept so the
      collate function can draw a random inpainting mask over it;
    * ``"instance_images"`` -- the same image as a tensor normalized with mean/std 0.5;
    * ``"instance_prompt_ids"`` -- the un-padded token ids of ``instance_prompt``;

    plus ``"class_images"``, ``"class_PIL_images"`` and ``"class_prompt_ids"`` when
    ``class_data_root`` is given. Instance and class images are indexed modulo their
    own counts, so the dataset length is the larger of the two counts.

    Raises:
        ValueError: if ``instance_data_root`` does not exist, or if an image directory
            contains no files (indexing would otherwise divide by zero).
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        """
        Args:
            instance_data_root: directory holding the instance (subject) images.
            instance_prompt: prompt paired with every instance image.
            tokenizer: tokenizer used to encode the prompts.
            class_data_root: optional directory of class images for prior
                preservation; created if missing.
            class_prompt: prompt paired with every class image.
            size: target resolution for resizing and cropping.
            center_crop: center-crop when true, random-crop otherwise.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = self._list_files(self.instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            raise ValueError(f"No instance images found in {self.instance_data_root}.")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = self._list_files(self.class_data_root)
            self.num_class_images = len(self.class_images_path)
            if self.num_class_images == 0:
                raise ValueError(f"No class images found in {self.class_data_root}.")
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms_resize_and_crop = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            ]
        )

        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    @staticmethod
    def _list_files(root):
        """Return the regular files directly inside ``root`` in a stable (sorted) order."""
        # Subdirectories would crash `Image.open`; sorting makes seeded runs reproducible
        # regardless of the filesystem's `iterdir` order.
        return sorted(path for path in root.iterdir() if path.is_file())

    def _load_image(self, path):
        """Open ``path`` as RGB and apply the resize/crop transform, closing the file afterwards."""
        with Image.open(path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            # The transform returns new image objects, so they outlive the closed file.
            return self.image_transforms_resize_and_crop(image)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` without padding; padding happens per batch in the collate function."""
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = self._load_image(self.instance_images_path[index % self.num_instance_images])
        example["PIL_images"] = instance_image
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root is not None:
            class_image = self._load_image(self.class_images_path[index % self.num_class_images])
            example["class_images"] = self.image_transforms(class_image)
            example["class_PIL_images"] = class_image
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


class PromptDataset(Dataset):
    """Dataset that yields the same prompt ``num_samples`` times, tagged with its index.

    It is used to spread class-image generation over several GPUs; the index of
    each item is later used to build a unique output filename.
    """

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the ``"<namespace>/<model_id>"`` Hub repository id.

    Args:
        model_id: repository name without namespace.
        organization: namespace to use; when ``None`` the namespace is the
            username returned by ``whoami`` for ``token``.
        token: Hub token; read from the local token store when ``None``.

    Returns:
        The fully qualified repository id string.
    """
    if token is None:
        token = HfFolder.get_token()
    namespace = whoami(token)["name"] if organization is None else organization
    return f"{namespace}/{model_id}"


def _generate_class_images(args, accelerator):
    """Top up ``args.class_data_dir`` with images sampled from the pretrained inpainting model.

    Does nothing when the directory already holds ``args.num_class_images`` entries.
    Each batch inpaints a random-noise image under a mask covering the whole frame,
    i.e. the pipeline synthesizes the full image from ``args.class_prompt``.
    """
    class_images_dir = Path(args.class_data_dir)
    # exist_ok avoids a FileExistsError race when several processes get here at once.
    class_images_dir.mkdir(parents=True, exist_ok=True)
    cur_class_images = len(list(class_images_dir.iterdir()))
    if cur_class_images >= args.num_class_images:
        return

    torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
    )
    pipeline.set_progress_bar_config(disable=True)

    num_new_images = args.num_class_images - cur_class_images
    logger.info(f"Number of class images to sample: {num_new_images}.")

    sample_dataset = PromptDataset(args.class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size, num_workers=1)

    # Sharded across processes, so each process generates a distinct subset of indices.
    sample_dataloader = accelerator.prepare(sample_dataloader)
    pipeline.to(accelerator.device)

    transform_to_pil = transforms.ToPILImage()
    for example in tqdm(
        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
    ):
        fake_pil_images = transform_to_pil(torch.rand((3, args.resolution, args.resolution)))
        fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)

        images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images

        for i, image in enumerate(images):
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
            image.save(image_filename)

    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _setup_output_dir(args, accelerator):
    """Create the output directory on the main process, cloning the Hub repo into it when pushing.

    Returns:
        The ``Repository`` handle when ``--push_to_hub`` is set and this is the main
        process, otherwise ``None``.
    """
    if not accelerator.is_main_process:
        return None

    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=args.hub_token)
        repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

        # The file is freshly truncated, so just write the patterns. `checkpoint-*`
        # keeps the resumable training states saved below out of the pushed model repo.
        with open(os.path.join(args.output_dir, ".gitignore"), "w") as gitignore:
            gitignore.write("step_*\n")
            gitignore.write("epoch_*\n")
            gitignore.write("checkpoint-*\n")
        return repo

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    return None


def main():
    """Fine-tune a Stable Diffusion inpainting UNet (optionally also its text encoder) with DreamBooth.

    Parses the CLI options, optionally samples class images for prior preservation,
    trains on the instance (and class) images with freshly drawn random masks, and
    saves the resulting inpainting pipeline to ``args.output_dir`` (pushing it to the
    Hub when requested).
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
        accelerator_project_config=accelerator_project_config,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.with_prior_preservation:
        _generate_class_images(args, accelerator)
        # Each process writes only its own shard; wait until all shards are on disk
        # before any process lists the directory to build the training dataset.
        accelerator.wait_for_everyone()

    repo = _setup_output_dir(args, accelerator)

    # Load the tokenizer (`--pretrained_model_name_or_path` is a required argument).
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    else:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    def trainable_parameters():
        # Looked up on every call so that, after `accelerator.prepare`, it iterates the wrapped models.
        if args.train_text_encoder:
            return itertools.chain(unet.parameters(), text_encoder.parameters())
        return unet.parameters()

    optimizer = optimizer_class(
        trainable_parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Batch dataset items and attach a freshly drawn random inpainting mask to every image.

        With prior preservation the class examples are appended after the instance
        examples (prompts, pixels and masks alike) so a single forward pass covers
        both, and `torch.chunk(..., 2)` in the loss splits them apart again.
        """
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        pil_images = [example["PIL_images"] for example in examples]

        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            pil_images += [example["class_PIL_images"] for example in examples]

        masks = []
        masked_images = []
        for pil_image in pil_images:
            mask, masked_image = prepare_mask_and_masked_image(pil_image, random_mask(pil_image.size, 1, False))
            masks.append(mask)
            masked_images.append(masked_image)

        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
        masks = torch.stack(masks)
        masked_images = torch.stack(masked_images)
        return {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )
    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    resume_step = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine. It already starts at `global_step`,
    # so it must not be advanced again while skipping already-trained batches.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    latent_size = args.resolution // 8

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            # `from_pretrained` returns models in eval mode; without this the text
            # encoder's gradient checkpointing (and dropout) would silently stay off.
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                # Resize the masks to the latent resolution, as they are concatenated to the latents.
                # Collated masks are (B, 1, 1, H, W); flatten to (B, 1, H, W) and resize in one call.
                masks = batch["masks"]
                mask = F.interpolate(masks.reshape(-1, 1, *masks.shape[-2:]), size=(latent_size, latent_size))

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # concatenate the noised latents with the mask and the masked latents
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(trainable_parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it. The trained UNet takes
    # masked-image inputs, so record it as an inpainting pipeline in model_index.json.
    if accelerator.is_main_process:
        pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()