Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ShivamShrirao
GitHub Repository: ShivamShrirao/diffusers
Path: blob/main/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
1979 views
1
#!/usr/bin/env python
2
# coding=utf-8
3
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
import argparse
17
import logging
18
import math
19
import os
20
import random
21
from pathlib import Path
22
from typing import Optional
23
24
import numpy as np
25
import PIL
26
import torch
27
import torch.nn.functional as F
28
import torch.utils.checkpoint
29
import transformers
30
from accelerate import Accelerator
31
from accelerate.logging import get_logger
32
from accelerate.utils import ProjectConfiguration, set_seed
33
from huggingface_hub import HfFolder, Repository, create_repo, whoami
34
from multi_token_clip import MultiTokenCLIPTokenizer
35
36
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
37
from packaging import version
38
from PIL import Image
39
from torch.utils.data import Dataset
40
from torchvision import transforms
41
from tqdm.auto import tqdm
42
from transformers import CLIPTextModel
43
44
import diffusers
45
from diffusers import (
46
AutoencoderKL,
47
DDPMScheduler,
48
DiffusionPipeline,
49
DPMSolverMultistepScheduler,
50
StableDiffusionPipeline,
51
UNet2DConditionModel,
52
)
53
from diffusers.optimization import get_scheduler
54
from diffusers.utils import check_min_version, is_wandb_available
55
from diffusers.utils.import_utils import is_xformers_available
56
57
58
# From Pillow 9.1.0 the resampling filters are accessed through the ``Image.Resampling``
# enum; older releases expose them as module-level constants on ``PIL.Image``.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
# ------------------------------------------------------------------------------


# Abort early when the installed diffusers is older than this script requires.
# Remove at your own risk.
check_min_version("0.14.0.dev0")

logger = get_logger(__name__)
81
82
83
def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
    """
    Register ``placeholder_token`` with the tokenizer and initialise its embedding rows.

    The tokenizer is asked to add the placeholder with ``num_vec_per_token`` vectors, and
    the text encoder's embedding table is resized to the new tokenizer length. Each row
    of the placeholder is then either copied from a sub-token of ``initializer_token``
    (sub-tokens are spread evenly across the placeholder's vectors) or, when no
    initializer is given, filled with standard-normal noise.

    Returns:
        ``placeholder_token``, unchanged.
    """
    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_table = text_encoder.get_input_embeddings().weight.data
    new_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)

    if not initializer_token:
        # No initializer: start each new vector from random noise.
        for new_id in new_ids:
            embedding_table[new_id] = torch.randn_like(embedding_table[new_id])
        return placeholder_token

    source_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    for offset, new_id in enumerate(new_ids):
        # Vector ``offset`` copies sub-token floor(offset * len(source_ids) / num_vec_per_token).
        source_id = source_ids[offset * len(source_ids) // num_vec_per_token]
        embedding_table[new_id] = embedding_table[source_id]
    return placeholder_token
99
100
101
def save_progress(tokenizer, text_encoder, accelerator, save_path):
    """
    Save the learned embeddings of every placeholder token to ``save_path``.

    The file holds a single dict mapping each placeholder token string in
    ``tokenizer.token_map`` to a detached CPU tensor whose first dimension indexes the
    placeholder's vectors, i.e. ``(num_vectors, embedding_dim)``. That is the layout
    ``load_multitoken_tokenizer`` reads back (it takes ``shape[0]`` as the vector count).
    """
    embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    learned_embeds_dict = {}
    for placeholder_token in tokenizer.token_map:
        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
        learned_embeds = embedding_weight[placeholder_token_ids]
        # Indexing with a list of ids already yields one row per vector; only a bare
        # scalar id would produce a 1-D row that needs the leading vector axis added.
        if learned_embeds.dim() == 1:
            learned_embeds = learned_embeds[None]
        learned_embeds_dict[placeholder_token] = learned_embeds.detach().cpu()
    # Write once, after collecting every placeholder, so all of them end up in the file.
    torch.save(learned_embeds_dict, save_path)
109
110
111
def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
    """
    Add each placeholder in ``learned_embeds_dict`` to the tokenizer and copy in its vectors.

    Args:
        tokenizer: Multi-vector tokenizer the placeholders are registered with.
        text_encoder: Text encoder whose input-embedding table is resized and written to.
        learned_embeds_dict: Maps each placeholder token string to a tensor whose first
            dimension indexes that placeholder's vectors.
    """
    for token, embeds in learned_embeds_dict.items():
        vector_count = embeds.shape[0]
        embeds = embeds.to(dtype=text_encoder.dtype)
        add_tokens(tokenizer, text_encoder, token, num_vec_per_token=vector_count)

        token_ids = tokenizer.encode(token, add_special_tokens=False)
        # Fetch the table only after add_tokens has resized it.
        weights = text_encoder.get_input_embeddings().weight.data
        for position, token_id in enumerate(token_ids):
            weights[token_id] = embeds[position]
121
122
123
def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
    """
    Load an embedding saved in Automatic1111's format, registering it as ``placeholder_token``.

    Such files look like::

        {'string_to_token': {'*': 265},
         'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
                                          [ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
                                          [ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
                                          [ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
                                         requires_grad=True)},
         'name': 'FloralMarble-400', 'step': 399,
         'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}

    Only the tensor under ``automatic_dict["string_to_param"]["*"]`` is used; it is stored
    under ``placeholder_token`` and the other keys are ignored.
    """
    embeds_by_token = {placeholder_token: automatic_dict["string_to_param"]["*"]}
    load_multitoken_tokenizer(tokenizer, text_encoder, embeds_by_token)
135
136
137
def get_mask(tokenizer, accelerator):
    """
    Build a boolean mask over the embedding rows that must stay frozen.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``len(tokenizer)`` on ``accelerator.device``.
        It is True for every ordinary vocabulary row and False for every row belonging
        to a placeholder token in ``tokenizer.token_map``.
    """
    mask = torch.ones(len(tokenizer), dtype=torch.bool, device=accelerator.device)
    for placeholder_token in tokenizer.token_map:
        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
        if placeholder_token_ids:
            # Clear all of this placeholder's rows with one indexed write, instead of
            # building and comparing a vocabulary-sized arange for every vector.
            mask[placeholder_token_ids] = False
    return mask
145
146
147
def parse_args(input_args=None):
    """
    Parse the command-line options of the multi-token textual inversion script.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic launches).

    Returns:
        ``argparse.Namespace`` with one attribute per option. If the ``LOCAL_RANK``
        environment variable is set to a value other than -1, it overrides ``local_rank``.

    Raises:
        SystemExit: Raised by argparse for invalid or missing required arguments.
        ValueError: If no training data directory was given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--progressive_tokens_max_steps",
        type=int,
        default=2000,
        help="The number of steps until all tokens will be used.",
    )
    parser.add_argument(
        "--progressive_tokens",
        action="store_true",
        help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
    )
    parser.add_argument("--vector_shuffle", action="store_true", help="Shuffling tokens during training")
    parser.add_argument(
        "--num_vec_per_token",
        type=int,
        default=1,
        help=(
            "The number of vectors used to represent the placeholder token. The higher the number, the better the"
            " result at the cost of editability. This can be fixed by prompt editing."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            # Each fragment ends with a space so the implicitly concatenated text reads correctly.
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    # ``None`` makes argparse fall back to sys.argv[1:], preserving the CLI behaviour.
    args = parser.parse_args(input_args)

    # Launchers such as torchrun communicate the rank through the environment.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Defensive: argparse already enforces required=True, so this normally never triggers.
    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args
402
403
404
# Prompt templates used to caption training images; "{}" is replaced by the placeholder
# token. Duplicate entries are kept as-is because they weight the random choice.

# Used when learning an object/subject (any learnable_property other than "style").
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Used when learnable_property == "style".
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
455
456
457
class TextualInversionDataset(Dataset):
    """
    Image/prompt dataset for multi-vector textual inversion training.

    Every regular, non-hidden file directly inside ``data_root`` is treated as an image
    and paired with a randomly chosen prompt template that contains
    ``placeholder_token``. In the ``"train"`` split the image list is repeated
    ``repeats`` times.

    Args:
        data_root: Directory containing the training images. Sub-directories and files
            whose name starts with "." (e.g. ``.DS_Store``) are ignored.
        tokenizer: Tokenizer whose ``encode`` accepts the ``vector_shuffle`` and
            ``prop_tokens_to_load`` keyword arguments (e.g. ``MultiTokenCLIPTokenizer``).
        learnable_property: ``"style"`` selects the style templates; any other value
            selects the object templates.
        size: Side length in pixels of the square output images.
        repeats: How many times each image is repeated in the ``"train"`` split.
        interpolation: Resampling filter name; any key of ``PIL_INTERPOLATION``.
        flip_p: Probability of a random horizontal flip.
        set: Dataset split; only ``"train"`` applies ``repeats``.
        placeholder_token: Token inserted into the prompt templates.
        center_crop: Center-crop to a square before resizing.
        vector_shuffle: Forwarded to ``tokenizer.encode``.
        progressive_tokens: When True, ``self.prop_tokens_to_load`` is forwarded to
            ``tokenizer.encode``; otherwise ``1.0`` is used.

    Raises:
        ValueError: If ``data_root`` contains no usable files.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
        vector_shuffle=False,
        progressive_tokens=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.vector_shuffle = vector_shuffle
        self.progressive_tokens = progressive_tokens
        # Proportion of the placeholder's vectors to load when progressive_tokens is on.
        # Starts at 0 here; presumably advanced by the training loop (not visible in this class).
        self.prop_tokens_to_load = 0

        # Sort so that the image order does not depend on the filesystem's listing order.
        # Skip sub-directories and hidden files, which Image.open would otherwise fail on.
        self.image_paths = [
            os.path.join(self.data_root, file_name)
            for file_name in sorted(os.listdir(self.data_root))
            if not file_name.startswith(".") and os.path.isfile(os.path.join(self.data_root, file_name))
        ]
        self.num_images = len(self.image_paths)
        if self.num_images == 0:
            # Fail here with a clear message rather than later in the sampler or in `i % 0`.
            raise ValueError(f"No image files found in training data directory {self.data_root!r}.")

        self._length = self.num_images * repeats if set == "train" else self.num_images

        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return ``{"input_ids": ..., "pixel_values": ...}`` for example ``i``."""
        example = {}
        # The context manager closes the underlying file; convert() returns an
        # independent copy (even for images that are already RGB).
        with Image.open(self.image_paths[i % self.num_images]) as raw_image:
            image = raw_image.convert("RGB")

        text = random.choice(self.templates).format(self.placeholder_token)
        example["input_ids"] = self.tokenizer.encode(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
            vector_shuffle=self.vector_shuffle,
            prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
        )[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[:2]
            crop = min(h, w)
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW (channel-first) tensor.
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
548
549
550
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Build the fully-qualified Hub repository id ``"<namespace>/<model_id>"``.

    If ``token`` is None, the locally stored Hub token is read first. The namespace is
    ``organization`` when given; otherwise it is the user name that ``whoami`` reports
    for the token.
    """
    resolved_token = HfFolder.get_token() if token is None else token
    namespace = whoami(resolved_token)["name"] if organization is None else organization
    return f"{namespace}/{model_id}"
558
559
560
def main():
561
args = parse_args()
562
logging_dir = os.path.join(args.output_dir, args.logging_dir)
563
564
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
565
566
accelerator = Accelerator(
567
gradient_accumulation_steps=args.gradient_accumulation_steps,
568
mixed_precision=args.mixed_precision,
569
log_with=args.report_to,
570
logging_dir=logging_dir,
571
project_config=accelerator_project_config,
572
)
573
574
if args.report_to == "wandb":
575
if not is_wandb_available():
576
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
577
import wandb
578
579
# Make one log on every process with the configuration for debugging.
580
logging.basicConfig(
581
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
582
datefmt="%m/%d/%Y %H:%M:%S",
583
level=logging.INFO,
584
)
585
logger.info(accelerator.state, main_process_only=False)
586
if accelerator.is_local_main_process:
587
transformers.utils.logging.set_verbosity_warning()
588
diffusers.utils.logging.set_verbosity_info()
589
else:
590
transformers.utils.logging.set_verbosity_error()
591
diffusers.utils.logging.set_verbosity_error()
592
593
# If passed along, set the training seed now.
594
if args.seed is not None:
595
set_seed(args.seed)
596
597
# Handle the repository creation
598
if accelerator.is_main_process:
599
if args.push_to_hub:
600
if args.hub_model_id is None:
601
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
602
else:
603
repo_name = args.hub_model_id
604
create_repo(repo_name, exist_ok=True, token=args.hub_token)
605
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
606
607
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
608
if "step_*" not in gitignore:
609
gitignore.write("step_*\n")
610
if "epoch_*" not in gitignore:
611
gitignore.write("epoch_*\n")
612
elif args.output_dir is not None:
613
os.makedirs(args.output_dir, exist_ok=True)
614
615
# Load tokenizer
616
if args.tokenizer_name:
617
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
618
elif args.pretrained_model_name_or_path:
619
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
620
621
# Load scheduler and models
622
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
623
text_encoder = CLIPTextModel.from_pretrained(
624
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
625
)
626
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
627
unet = UNet2DConditionModel.from_pretrained(
628
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
629
)
630
if is_xformers_available():
631
try:
632
unet.enable_xformers_memory_efficient_attention()
633
except Exception as e:
634
logger.warning(
635
"Could not enable memory efficient attention. Make sure xformers is installed"
636
f" correctly and a GPU is available: {e}"
637
)
638
add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)
639
640
# Freeze vae and unet
641
vae.requires_grad_(False)
642
unet.requires_grad_(False)
643
# Freeze all parameters except for the token embeddings in text encoder
644
text_encoder.text_model.encoder.requires_grad_(False)
645
text_encoder.text_model.final_layer_norm.requires_grad_(False)
646
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
647
648
if args.gradient_checkpointing:
649
# Keep unet in train mode if we are using gradient checkpointing to save memory.
650
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
651
unet.train()
652
text_encoder.gradient_checkpointing_enable()
653
unet.enable_gradient_checkpointing()
654
655
if args.enable_xformers_memory_efficient_attention:
656
if is_xformers_available():
657
import xformers
658
659
xformers_version = version.parse(xformers.__version__)
660
if xformers_version == version.parse("0.0.16"):
661
logger.warn(
662
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
663
)
664
unet.enable_xformers_memory_efficient_attention()
665
else:
666
raise ValueError("xformers is not available. Make sure it is installed correctly")
667
668
# Enable TF32 for faster training on Ampere GPUs,
669
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
670
if args.allow_tf32:
671
torch.backends.cuda.matmul.allow_tf32 = True
672
673
if args.scale_lr:
674
args.learning_rate = (
675
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
676
)
677
678
# Initialize the optimizer
679
optimizer = torch.optim.AdamW(
680
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
681
lr=args.learning_rate,
682
betas=(args.adam_beta1, args.adam_beta2),
683
weight_decay=args.adam_weight_decay,
684
eps=args.adam_epsilon,
685
)
686
687
# Dataset and DataLoaders creation:
688
train_dataset = TextualInversionDataset(
689
data_root=args.train_data_dir,
690
tokenizer=tokenizer,
691
size=args.resolution,
692
placeholder_token=args.placeholder_token,
693
repeats=args.repeats,
694
learnable_property=args.learnable_property,
695
center_crop=args.center_crop,
696
set="train",
697
)
698
train_dataloader = torch.utils.data.DataLoader(
699
train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
700
)
701
702
# Scheduler and math around the number of training steps.
703
overrode_max_train_steps = False
704
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
705
if args.max_train_steps is None:
706
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
707
overrode_max_train_steps = True
708
709
lr_scheduler = get_scheduler(
710
args.lr_scheduler,
711
optimizer=optimizer,
712
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
713
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
714
)
715
716
# Prepare everything with our `accelerator`.
717
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
718
text_encoder, optimizer, train_dataloader, lr_scheduler
719
)
720
721
# For mixed precision training we cast the unet and vae weights to half-precision
722
# as these models are only used for inference, keeping weights in full precision is not required.
723
weight_dtype = torch.float32
724
if accelerator.mixed_precision == "fp16":
725
weight_dtype = torch.float16
726
elif accelerator.mixed_precision == "bf16":
727
weight_dtype = torch.bfloat16
728
729
# Move vae and unet to device and cast to weight_dtype
730
unet.to(accelerator.device, dtype=weight_dtype)
731
vae.to(accelerator.device, dtype=weight_dtype)
732
733
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
734
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
735
if overrode_max_train_steps:
736
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
737
# Afterwards we recalculate our number of training epochs
738
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
739
740
# We need to initialize the trackers we use, and also store our configuration.
741
# The trackers initialize automatically on the main process.
742
if accelerator.is_main_process:
743
accelerator.init_trackers("textual_inversion", config=vars(args))
744
745
# Train!
746
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
747
748
logger.info("***** Running training *****")
749
logger.info(f" Num examples = {len(train_dataset)}")
750
logger.info(f" Num Epochs = {args.num_train_epochs}")
751
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
752
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
753
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
754
logger.info(f" Total optimization steps = {args.max_train_steps}")
755
global_step = 0
756
first_epoch = 0
757
758
# Potentially load in the weights and states from a previous save
759
if args.resume_from_checkpoint:
760
if args.resume_from_checkpoint != "latest":
761
path = os.path.basename(args.resume_from_checkpoint)
762
else:
763
# Get the most recent checkpoint
764
dirs = os.listdir(args.output_dir)
765
dirs = [d for d in dirs if d.startswith("checkpoint")]
766
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
767
path = dirs[-1] if len(dirs) > 0 else None
768
769
if path is None:
770
accelerator.print(
771
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
772
)
773
args.resume_from_checkpoint = None
774
else:
775
accelerator.print(f"Resuming from checkpoint {path}")
776
accelerator.load_state(os.path.join(args.output_dir, path))
777
global_step = int(path.split("-")[1])
778
779
resume_global_step = global_step * args.gradient_accumulation_steps
780
first_epoch = global_step // num_update_steps_per_epoch
781
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
782
783
# Only show the progress bar once on each machine.
784
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
785
progress_bar.set_description("Steps")
786
787
# keep original embeddings as reference
788
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
789
790
for epoch in range(first_epoch, args.num_train_epochs):
791
text_encoder.train()
792
for step, batch in enumerate(train_dataloader):
793
# Skip steps until we reach the resumed step
794
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
795
if step % args.gradient_accumulation_steps == 0:
796
progress_bar.update(1)
797
continue
798
if args.progressive_tokens:
799
train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps
800
801
with accelerator.accumulate(text_encoder):
802
# Convert images to latent space
803
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
804
latents = latents * vae.config.scaling_factor
805
806
# Sample noise that we'll add to the latents
807
noise = torch.randn_like(latents)
808
bsz = latents.shape[0]
809
# Sample a random timestep for each image
810
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
811
timesteps = timesteps.long()
812
813
# Add noise to the latents according to the noise magnitude at each timestep
814
# (this is the forward diffusion process)
815
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
816
817
# Get the text embedding for conditioning
818
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
819
820
# Predict the noise residual
821
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
822
823
# Get the target for loss depending on the prediction type
824
if noise_scheduler.config.prediction_type == "epsilon":
825
target = noise
826
elif noise_scheduler.config.prediction_type == "v_prediction":
827
target = noise_scheduler.get_velocity(latents, noise, timesteps)
828
else:
829
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
830
831
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
832
833
accelerator.backward(loss)
834
835
optimizer.step()
836
lr_scheduler.step()
837
optimizer.zero_grad()
838
839
# Let's make sure we don't update any embedding weights besides the newly added token
840
index_no_updates = get_mask(tokenizer, accelerator)
841
with torch.no_grad():
842
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
843
index_no_updates
844
] = orig_embeds_params[index_no_updates]
845
846
# Checks if the accelerator has performed an optimization step behind the scenes
847
if accelerator.sync_gradients:
848
progress_bar.update(1)
849
global_step += 1
850
if global_step % args.save_steps == 0:
851
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
852
save_progress(tokenizer, text_encoder, accelerator, save_path)
853
854
if global_step % args.checkpointing_steps == 0:
855
if accelerator.is_main_process:
856
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
857
accelerator.save_state(save_path)
858
logger.info(f"Saved state to {save_path}")
859
860
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
861
progress_bar.set_postfix(**logs)
862
accelerator.log(logs, step=global_step)
863
864
if global_step >= args.max_train_steps:
865
break
866
867
if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
868
logger.info(
869
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
870
f" {args.validation_prompt}."
871
)
872
# create pipeline (note: reuses the in-memory unet, vae and text encoder instead of reloading them)
873
pipeline = DiffusionPipeline.from_pretrained(
874
args.pretrained_model_name_or_path,
875
text_encoder=accelerator.unwrap_model(text_encoder),
876
tokenizer=tokenizer,
877
unet=unet,
878
vae=vae,
879
revision=args.revision,
880
torch_dtype=weight_dtype,
881
)
882
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
883
pipeline = pipeline.to(accelerator.device)
884
pipeline.set_progress_bar_config(disable=True)
885
886
# run inference
887
generator = (
888
None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
889
)
890
images = []
891
for _ in range(args.num_validation_images):
892
with torch.autocast("cuda"):
893
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
894
images.append(image)
895
896
for tracker in accelerator.trackers:
897
if tracker.name == "tensorboard":
898
np_images = np.stack([np.asarray(img) for img in images])
899
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
900
if tracker.name == "wandb":
901
tracker.log(
902
{
903
"validation": [
904
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
905
for i, image in enumerate(images)
906
]
907
}
908
)
909
910
del pipeline
911
torch.cuda.empty_cache()
912
913
# Create the pipeline using the trained modules and save it.
914
accelerator.wait_for_everyone()
915
if accelerator.is_main_process:
916
if args.push_to_hub and args.only_save_embeds:
917
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
918
save_full_model = True
919
else:
920
save_full_model = not args.only_save_embeds
921
if save_full_model:
922
pipeline = StableDiffusionPipeline.from_pretrained(
923
args.pretrained_model_name_or_path,
924
text_encoder=accelerator.unwrap_model(text_encoder),
925
vae=vae,
926
unet=unet,
927
tokenizer=tokenizer,
928
)
929
pipeline.save_pretrained(args.output_dir)
930
# Save the newly trained embeddings
931
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
932
save_progress(tokenizer, text_encoder, accelerator, save_path)
933
934
if args.push_to_hub:
935
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
936
937
accelerator.end_training()
938
939
940
if __name__ == "__main__":
    # Only launch training when this file is executed directly as a script;
    # importing the module (e.g. to reuse its helpers) must not start a run.
    main()
943