Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ShivamShrirao
GitHub Repository: ShivamShrirao/diffusers
Path: blob/main/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
1980 views
1
#!/usr/bin/env python
2
# coding=utf-8
3
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
import argparse
17
import logging
18
import math
19
import os
20
import random
21
from pathlib import Path
22
from typing import Optional
23
24
import datasets
25
import numpy as np
26
import PIL
27
import torch
28
import torch.nn.functional as F
29
import torch.utils.checkpoint
30
import transformers
31
from accelerate import Accelerator
32
from accelerate.logging import get_logger
33
from accelerate.utils import ProjectConfiguration, set_seed
34
from huggingface_hub import HfFolder, Repository, create_repo, whoami
35
from onnxruntime.training.ortmodule import ORTModule
36
37
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
38
from packaging import version
39
from PIL import Image
40
from torch.utils.data import Dataset
41
from torchvision import transforms
42
from tqdm.auto import tqdm
43
from transformers import CLIPTextModel, CLIPTokenizer
44
45
import diffusers
46
from diffusers import (
47
AutoencoderKL,
48
DDPMScheduler,
49
DiffusionPipeline,
50
DPMSolverMultistepScheduler,
51
StableDiffusionPipeline,
52
UNet2DConditionModel,
53
)
54
from diffusers.optimization import get_scheduler
55
from diffusers.utils import check_min_version, is_wandb_available
56
from diffusers.utils.import_utils import is_xformers_available
57
58
59
# Map interpolation names to Pillow resampling filters, choosing the spelling that
# matches the installed Pillow release (the ``Resampling`` enum is used from 9.1.0 on).
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    # Older Pillow: the filters are module-level constants on ``PIL.Image``.
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.LINEAR,
        bilinear=PIL.Image.BILINEAR,
        bicubic=PIL.Image.BICUBIC,
        lanczos=PIL.Image.LANCZOS,
        nearest=PIL.Image.NEAREST,
    )
else:
    # Pillow >= 9.1.0: the filters live on ``PIL.Image.Resampling``; "linear" maps to
    # the bilinear filter.
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.Resampling.BILINEAR,
        bilinear=PIL.Image.Resampling.BILINEAR,
        bicubic=PIL.Image.Resampling.BICUBIC,
        lanczos=PIL.Image.Resampling.LANCZOS,
        nearest=PIL.Image.Resampling.NEAREST,
    )
# ------------------------------------------------------------------------------


# Raises when the installed diffusers is older than this script needs.
# Remove at your own risks.
check_min_version("0.13.0.dev0")

logger = get_logger(__name__)
82
83
84
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Serialize the learned placeholder embedding to ``save_path``.

    Writes (via ``torch.save``) a one-entry dict mapping ``args.placeholder_token`` to a
    detached CPU copy of row ``placeholder_token_id`` of the unwrapped text encoder's
    input-embedding weight.
    """
    logger.info("Saving embeddings")
    base_encoder = accelerator.unwrap_model(text_encoder)
    embedding_weight = base_encoder.get_input_embeddings().weight
    placeholder_vector = embedding_weight[placeholder_token_id]
    torch.save({args.placeholder_token: placeholder_vector.detach().cpu()}, save_path)
89
90
91
def parse_args(input_args=None):
    """Parse the command-line options for textual-inversion training.

    Args:
        input_args: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` reads ``sys.argv[1:]`` as before; passing a list lets
            the script be driven programmatically or from tests.

    Returns:
        argparse.Namespace: The parsed options. ``local_rank`` is replaced by the
        ``LOCAL_RANK`` environment variable when that is set to something other than
        -1 and differs from the command-line value.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            # main() joins this with --output_dir, so the default resolves to output_dir/logs.
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, resolved relative to"
            " `--output_dir`. Defaults to *output_dir/logs*."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires"
            " PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    # parse_args(None) falls back to sys.argv[1:], preserving the CLI behaviour.
    args = parser.parse_args(input_args)

    # A launcher-provided LOCAL_RANK takes precedence over the --local_rank flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # No explicit --train_data_dir check is needed: it is declared required=True, so
    # argparse already exits with a usage error when it is missing.
    return args
325
326
327
# Caption templates sampled per training example; "{}" is replaced by the placeholder
# token. The object set is used unless --learnable_property is "style".
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Style-learning templates. Two entries appear twice on purpose-or-by-history; they are
# kept so the sampling distribution is unchanged.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
378
379
380
class TextualInversionDataset(Dataset):
    """Image/caption pairs for textual-inversion training.

    Every regular file in ``data_root`` is treated as a training image. Each item pairs
    one image (optionally center-cropped, resized to ``size`` x ``size`` and randomly
    flipped) with the token ids of a caption drawn at random from a template list and
    filled in with ``placeholder_token``.

    Args:
        data_root: Directory containing the training images.
        tokenizer: Tokenizer used to encode captions (padded/truncated to
            ``tokenizer.model_max_length``).
        learnable_property: ``"style"`` selects the style templates; any other value
            selects the object templates.
        size: Side length in pixels of the square output image.
        repeats: How often each image is repeated when ``set == "train"``.
        interpolation: Key of ``PIL_INTERPOLATION`` (e.g. ``"bicubic"``); an unknown
            key raises ``KeyError``.
        flip_p: Probability of a random horizontal flip.
        set: ``"train"`` applies ``repeats``; any other value yields each image once.
        placeholder_token: String substituted into the caption templates.
        center_crop: Crop the largest centered square before resizing.

    Raises:
        ValueError: If ``data_root`` contains no regular files.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Sub-directories cannot be opened as images, so keep regular files only; sort so
        # the index -> file mapping does not depend on the filesystem's listing order.
        candidate_paths = (os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root))
        self.image_paths = sorted(path for path in candidate_paths if os.path.isfile(path))

        self.num_images = len(self.image_paths)
        if self.num_images == 0:
            # Fail early with a clear message instead of an opaque division-by-zero later.
            raise ValueError(f"No image files found in training data directory {self.data_root!r}.")

        # In training mode each image is visited `repeats` times per epoch.
        self._length = self.num_images * repeats if set == "train" else self.num_images

        # Use the shared name -> filter table so every supported filter name is accepted.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return ``{"input_ids", "pixel_values"}`` for image ``i % num_images``.

        ``input_ids`` holds the tokenized caption; ``pixel_values`` is a float32 tensor
        of shape ``(3, size, size)`` with values scaled to ``[-1, 1]``.
        """
        example = {}

        # The context manager closes the image file once its pixels have been copied out.
        with Image.open(self.image_paths[i % self.num_images]) as raw_image:
            rgb_image = raw_image if raw_image.mode == "RGB" else raw_image.convert("RGB")
            # default to score-sde preprocessing
            img = np.array(rgb_image).astype(np.uint8)

        text = random.choice(self.templates).format(self.placeholder_token)
        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        if self.center_crop:
            # Keep the largest square centered in the image.
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # uint8 [0, 255] -> float32 [-1, 1]
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW for PyTorch.
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
464
465
466
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return the Hub repository id ``"<namespace>/<model_id>"``.

    ``token`` falls back to the locally stored Hub token when omitted. The namespace is
    ``organization`` when given, otherwise the user name that ``whoami`` reports for
    ``token``.
    """
    token = HfFolder.get_token() if token is None else token
    namespace = whoami(token)["name"] if organization is None else organization
    return f"{namespace}/{model_id}"
474
475
476
def main():
    """Run textual-inversion training with the text encoder wrapped in ``ORTModule``.

    Steps:
      1. Parse the CLI and set up accelerate, logging and the optional Hub repository.
      2. Load the tokenizer, noise scheduler, text encoder, VAE and UNet from
         ``--pretrained_model_name_or_path``.
      3. Add ``--placeholder_token`` to the tokenizer, initialized from the embedding of
         ``--initializer_token``. Freeze the VAE, the UNet and every text-encoder part
         except the token embeddings.
      4. Optimize the input embeddings. After every step all rows other than the
         placeholder's are restored, so only the new token's vector changes.
      5. Periodically save the learned embedding and accelerate checkpoints, and
         optionally run validation.
      6. At the end, save ``learned_embeds.bin`` (plus the full pipeline unless
         ``--only_save_embeds``) and optionally push to the Hub.
    """
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Keep intermediate step/epoch artifacts out of the Hub repo. The file is
            # truncated on open, so both patterns are always written.
            with open(os.path.join(args.output_dir, ".gitignore"), "w") as gitignore:
                gitignore.write("step_*\n")
                gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # The initializer must map to exactly one token: zero tokens (e.g. an empty string)
    # would make token_ids[0] below fail, and several tokens are ambiguous.
    if len(token_ids) != 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # Run the trainable text encoder's forward/backward through ONNX Runtime.
    text_encoder = ORTModule(text_encoder)

    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine. It starts at `global_step`, so
    # steps restored from a checkpoint are already accounted for.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
    # Mask of every embedding row except the placeholder's; loop-invariant, so build it once.
    index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id

    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step. They are already included in
            # the progress bar's starting position, so they are not counted again here.
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                continue

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # Only the main process writes, so processes never race on the same file.
                if accelerator.is_main_process and global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            logger.info(
                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                f" {args.validation_prompt}."
            )
            # create pipeline (note: unet and vae are loaded again in float32). The
            # training tokenizer is passed explicitly because the checkpoint's own
            # tokenizer does not contain the placeholder token, so the prompt would not
            # use the learned embedding.
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                tokenizer=tokenizer,
                revision=args.revision,
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            generator = (
                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
            )
            prompt = args.num_validation_images * [args.validation_prompt]
            images = pipeline(prompt, num_inference_steps=25, generator=generator).images

            for tracker in accelerator.trackers:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

            del pipeline
            torch.cuda.empty_cache()

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
861
862