GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/controlnet/train_controlnet.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__)
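
# Illustrative launch command (a sketch, not the canonical recipe -- the model id, dataset name,
# image path, and prompt below are assumptions; adapt them to your setup and see the
# examples/controlnet README for the maintained instructions):
#
#   accelerate launch train_controlnet.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --dataset_name="fusing/fill50k" \
#     --output_dir="controlnet-model" \
#     --resolution=512 \
#     --learning_rate=1e-5 \
#     --train_batch_size=4 \
#     --validation_image="./conditioning_image_1.png" \
#     --validation_prompt="red circle with blue background" \
#     --tracker_project_name="train_controlnet"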


def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation... ")

    controlnet = accelerator.unwrap_model(controlnet)

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    if len(args.validation_image) == len(args.validation_prompt):
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt
    elif len(args.validation_image) == 1:
        validation_images = args.validation_image * len(args.validation_prompt)
        validation_prompts = args.validation_prompt
    elif len(args.validation_prompt) == 1:
        validation_images = args.validation_image
        validation_prompts = args.validation_prompt * len(args.validation_image)
    else:
        raise ValueError(
            "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
        )

    image_logs = []

    for validation_prompt, validation_image in zip(validation_prompts, validation_images):
        validation_image = Image.open(validation_image)

        images = []

        for _ in range(args.num_validation_images):
            with torch.autocast("cuda"):
                image = pipeline(
                    validation_prompt, validation_image, num_inference_steps=20, generator=generator
                ).images[0]

            images.append(image)

        image_logs.append(
            {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
        )

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]

                formatted_images = []

                formatted_images.append(np.asarray(validation_image))

                for image in images:
                    formatted_images.append(np.asarray(image))

                formatted_images = np.stack(formatted_images)

                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            formatted_images = []

            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]

                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))

                for image in images:
                    image = wandb.Image(image, caption=validation_prompt)
                    formatted_images.append(image)

            tracker.log({"validation": formatted_images})
        else:
            logger.warn(f"image logging not implemented for {tracker.name}")


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")

def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
        " If not specified, controlnet weights are initialized from the unet.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="controlnet-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`."
            " In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
            " Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
            " See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
            " instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more details"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the controlnet conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning images evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` `args.num_validation_images` times"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_controlnet",
        required=True,
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers."
            " For more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")

    if args.dataset_name is not None and args.train_data_dir is not None:
        raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.validation_prompt is not None and args.validation_image is None:
        raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")

    if args.validation_prompt is None and args.validation_image is not None:
        raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

    if (
        args.validation_image is not None
        and args.validation_prompt is not None
        and len(args.validation_image) != 1
        and len(args.validation_prompt) != 1
        and len(args.validation_image) != len(args.validation_prompt)
    ):
        raise ValueError(
            "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
            " or the same number of `--validation_prompt`s and `--validation_image`s"
        )

    return args

def make_train_dataset(args, tokenizer, accelerator):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    if args.image_column is None:
        image_column = column_names[0]
        logger.info(f"image column defaulting to {image_column}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
        logger.info(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset

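# A note on the batch produced by `collate_fn` below (derived from the transforms in
# `make_train_dataset`; shapes assume square source images, since `Resize` with a single int
# only fixes the shorter edge, and assume the standard 77-token CLIP tokenizer):
#   pixel_values              -- float (B, 3, resolution, resolution), normalized to [-1, 1]
#   conditioning_pixel_values -- float (B, 3, resolution, resolution), left in [0, 1]
#   input_ids                 -- int64 (B, 77), captions padded/truncated to tokenizer.model_max_length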
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    if args.controlnet_model_name_or_path:
        logger.info("Loading existing controlnet weights")
        controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
    else:
        logger.info("Initializing controlnet weights from unet")
        controlnet = ControlNetModel.from_unet(unet)

    # `accelerate` 0.16.0 will have better support for customized saving
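    # With these hooks, `accelerator.save_state(...)` called below writes the ControlNet weights to
    # `<output_dir>/checkpoint-<global_step>/controlnet/`, and `load_model_hook` reads that same
    # layout back when resuming via `--resume_from_checkpoint`.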
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            i = len(weights) - 1

            while len(weights) > 0:
                weights.pop()
                model = models[i]

                sub_dir = "controlnet"
                model.save_pretrained(os.path.join(output_dir, sub_dir))

                i -= 1

        def load_model_hook(models, input_dir):
            while len(models) > 0:
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)
    controlnet.train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing()

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, a copy of the weights should still be in float32."
    )

    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
        raise ValueError(
            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
        )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = controlnet.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset = make_train_dataset(args, tokenizer, accelerator)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        controlnet, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
    # as these models are only used for inference; keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))

        # tensorboard cannot handle list types for config
        tracker_config.pop("validation_prompt")
        tracker_config.pop("validation_image")

        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
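                # A sketch of what add_noise computes for the standard DDPM formulation:
                #   noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise
                # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.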
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

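                # The ControlNet consumes the noisy latents plus the conditioning image and returns
                # residuals that are added to the UNet's down-block skip connections and mid-block
                # activation below; this is how the spatial conditioning steers the denoiser.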
                down_block_res_samples, mid_block_res_sample = controlnet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=controlnet_image,
                    return_dict=False,
                )

                # Predict the noise residual
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = controlnet.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                        log_validation(
                            vae,
                            text_encoder,
                            tokenizer,
                            unet,
                            controlnet,
                            args,
                            accelerator,
                            weight_dtype,
                            global_step,
                        )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = accelerator.unwrap_model(controlnet)
        controlnet.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
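
# Inference sketch with the weights saved above (illustrative; the base model id and the
# conditioning image are assumptions -- pair the trained ControlNet with the same base model
# it was trained against):
#
#   from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
#
#   controlnet = ControlNetModel.from_pretrained("controlnet-model")
#   pipe = StableDiffusionControlNetPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", controlnet=controlnet
#   )
#   pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
#   image = pipe("a prompt", conditioning_image, num_inference_steps=20).images[0]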