GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/text_to_image/train_text_to_image.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
parser.add_argument("--num_train_epochs", type=int, default=100)
155
parser.add_argument(
156
"--max_train_steps",
157
type=int,
158
default=None,
159
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
160
)
161
parser.add_argument(
162
"--gradient_accumulation_steps",
163
type=int,
164
default=1,
165
help="Number of updates steps to accumulate before performing a backward/update pass.",
166
)
167
parser.add_argument(
168
"--gradient_checkpointing",
169
action="store_true",
170
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
171
)
172
parser.add_argument(
173
"--learning_rate",
174
type=float,
175
default=1e-4,
176
help="Initial learning rate (after the potential warmup period) to use.",
177
)
178
parser.add_argument(
179
"--scale_lr",
180
action="store_true",
181
default=False,
182
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
183
)
184
parser.add_argument(
185
"--lr_scheduler",
186
type=str,
187
default="constant",
188
help=(
189
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
190
' "constant", "constant_with_warmup"]'
191
),
192
)
193
parser.add_argument(
194
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
195
)
196
parser.add_argument(
197
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
198
)
199
parser.add_argument(
200
"--allow_tf32",
201
action="store_true",
202
help=(
203
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
204
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
205
),
206
)
207
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
208
parser.add_argument(
209
"--non_ema_revision",
210
type=str,
211
default=None,
212
required=False,
213
help=(
214
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
215
" remote repository specified with --pretrained_model_name_or_path."
216
),
217
)
218
parser.add_argument(
219
"--dataloader_num_workers",
220
type=int,
221
default=0,
222
help=(
223
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
224
),
225
)
226
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
227
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
228
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
229
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
230
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
231
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
232
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
233
parser.add_argument(
234
"--hub_model_id",
235
type=str,
236
default=None,
237
help="The name of the repository to keep in sync with the local `output_dir`.",
238
)
239
parser.add_argument(
240
"--logging_dir",
241
type=str,
242
default="logs",
243
help=(
244
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
245
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
246
),
247
)
248
parser.add_argument(
249
"--mixed_precision",
250
type=str,
251
default=None,
252
choices=["no", "fp16", "bf16"],
253
help=(
254
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
255
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
256
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
257
),
258
)
259
parser.add_argument(
260
"--report_to",
261
type=str,
262
default="tensorboard",
263
help=(
264
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
265
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
266
),
267
)
268
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
269
parser.add_argument(
270
"--checkpointing_steps",
271
type=int,
272
default=500,
273
help=(
274
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
275
" training using `--resume_from_checkpoint`."
276
),
277
)
278
parser.add_argument(
279
"--checkpoints_total_limit",
280
type=int,
281
default=None,
282
help=(
283
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
284
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
285
" for more docs"
286
),
287
)
288
parser.add_argument(
289
"--resume_from_checkpoint",
290
type=str,
291
default=None,
292
help=(
293
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
294
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
295
),
296
)
297
parser.add_argument(
298
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
299
)
300
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
301
302
args = parser.parse_args()
303
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
304
if env_local_rank != -1 and env_local_rank != args.local_rank:
305
args.local_rank = env_local_rank
306
307
# Sanity checks
308
if args.dataset_name is None and args.train_data_dir is None:
309
raise ValueError("Need either a dataset name or a training folder.")
310
311
# default to using the same revision for the non-ema model if not specified
312
if args.non_ema_revision is None:
313
args.non_ema_revision = args.revision
314
315
return args
316


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


dataset_name_mapping = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}
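# Note: dataset_name_mapping lists default (image column, caption column) names for known Hub datasets;
# main() consults it only when the image/caption column arguments are None.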


def main():
    args = parse_args()

    if args.non_ema_revision is not None:
        deprecate(
            "non_ema_revision!=None",
            "0.15.0",
            message=(
                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                " use `--variant=non_ema` instead."
            ),
        )
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Create EMA for the unet.
    if args.use_ema:
        ema_unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
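        # The EMA copy tracks an exponential moving average of the UNet weights; at the end of
        # training, ema_unet.copy_to() writes these averaged weights into the saved pipeline.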

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                ema_unet.load_state_dict(load_model.state_dict())
                ema_unet.to(accelerator.device)
                del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )
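    # Note: with --scale_lr, the learning rate above is multiplied by the effective global batch size
    # (gradient accumulation steps * per-device batch size * number of processes), i.e. the linear
    # scaling rule for large-batch training.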

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
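    # The transforms above resize the shorter image side to args.resolution, crop to a square,
    # optionally flip, and rescale pixel values from [0, 1] to [-1, 1], the input range the VAE expects.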

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_unet.to(accelerator.device)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
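            # resume_step counts dataloader batches (not optimizer updates) within the resumed epoch,
            # so the training loop below can skip the batches it has already seen.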

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
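                # vae.config.scaling_factor (0.18215 for the Stable Diffusion v1 VAE) rescales the latents
                # to roughly unit variance, matching the scale the UNet was originally trained on.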

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
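                # If --noise_offset is set, the block below adds a small noise component that is constant
                # per sample and channel, which has been reported to help the model learn very dark and
                # very bright images (see the linked blog post).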
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Get the target for loss depending on the prediction type
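                # ("epsilon" trains the UNet to predict the added noise itself; "v_prediction" trains it to
                # predict the velocity target returned by noise_scheduler.get_velocity(), roughly
                # sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents.)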
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
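                # The gathered average is used only for logging; the backward pass below uses the
                # per-process `loss`, so gradient accumulation and clipping behave as usual.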
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            revision=args.revision,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
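

# Example usage (illustrative only -- the model, dataset, paths, and hyperparameters below are
# placeholders, not values prescribed by this script; see the repository README for the maintained
# command):
#
#   accelerate launch train_text_to_image.py \
#     --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
#     --dataset_name="lambdalabs/pokemon-blip-captions" \
#     --resolution=512 --center_crop --random_flip \
#     --train_batch_size=1 --gradient_accumulation_steps=4 --gradient_checkpointing \
#     --mixed_precision="fp16" \
#     --max_train_steps=15000 \
#     --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=0 \
#     --output_dir="sd-pokemon-model"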