GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/textual_inversion/textual_inversion.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

if is_wandb_available():
    import wandb

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__)

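# ---------------------------------------------------------------------------
# Example invocation (added for reference; not part of the original script).
# The model id and paths below are placeholders -- substitute your own. All
# flags used here are defined in parse_args() further down.
#
#   accelerate launch textual_inversion.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --train_data_dir="./concept_images" \
#     --learnable_property="object" \
#     --placeholder_token="<my-concept>" \
#     --initializer_token="toy" \
#     --resolution=512 \
#     --train_batch_size=1 \
#     --gradient_accumulation_steps=4 \
#     --max_train_steps=3000 \
#     --learning_rate=5.0e-04 --scale_lr \
#     --lr_scheduler="constant" \
#     --lr_warmup_steps=0 \
#     --output_dir="./textual_inversion_output"
# ---------------------------------------------------------------------------
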
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    # create pipeline (note: unet and vae are loaded again in float32)
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        tokenizer=tokenizer,
        unet=unet,
        vae=vae,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        with torch.autocast("cuda"):
            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
        images.append(image)

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    del pipeline
    torch.cuda.empty_cache()

def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)

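# NOTE (added comment): `learned_embeds.bin` written above is a one-entry dict mapping the placeholder
# token to its learned embedding vector. A minimal sketch of loading it back into a pipeline at
# inference time (model id and file path are placeholders; this mirrors what save_progress() stores
# rather than any official loader):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   learned_embeds = torch.load("learned_embeds.bin", map_location="cpu")
#   token, embedding = next(iter(learned_embeds.items()))
#   pipe.tokenizer.add_tokens(token)
#   pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
#   token_id = pipe.tokenizer.convert_tokens_to_ids(token)
#   pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
#   image = pipe(f"a photo of {token}").images[0]
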
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X update steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=None,
        help=(
            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

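# NOTE (added comment): TextualInversionDataset simply lists every file in `data_root`, so
# `--train_data_dir` should be a flat folder containing only the concept images (the textual
# inversion paper typically uses around 3-5); any non-image file placed there would make
# Image.open() fail in __getitem__. This is a usage hint, not a check enforced by the script.
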
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

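    # NOTE (added comment): the `logging_dir=` keyword passed to Accelerator below matches the accelerate
    # release this script was written against; later accelerate versions moved that setting onto
    # ProjectConfiguration. Treat this as an assumption to verify against your installed accelerate
    # version, not as part of the original script.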
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

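    # Worked example of the scaling above (illustrative numbers, not from the original script): with
    # --scale_lr, --learning_rate 1e-4, --gradient_accumulation_steps 1 and --train_batch_size 16 on a
    # single process, the effective learning rate becomes 1e-4 * 1 * 16 * 1 = 1.6e-3.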
    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )
    if args.validation_epochs is not None:
        warnings.warn(
            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
            " Deprecated validation_epochs in favor of `validation_steps`."
            f" Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
            FutureWarning,
            stacklevel=2,
        )
        args.validation_steps = args.validation_epochs * len(train_dataset)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

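                # (added note) This is the standard latent-diffusion objective: an MSE between the UNet's
                # prediction and either the sampled noise ("epsilon") or the velocity target ("v_prediction").
                # Only the placeholder token's embedding row effectively learns, because every other row is
                # copied back from orig_embeds_params right after the optimizer step below.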
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()