GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""

import argparse
import logging
import math
import os
from pathlib import Path
from typing import Optional

import accelerate
import datasets
import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")

DATASET_NAME_MAPPING = {
    "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
}
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--original_image_column",
        type=str,
        default="input_image",
help="The column of the dataset containing the original image on which edits where made.",
    )
    parser.add_argument(
        "--edited_image_column",
        type=str,
        default="edited_image",
        help="The column of the dataset containing the edited image.",
    )
    parser.add_argument(
        "--edit_prompt_column",
        type=str,
        default="edit_prompt",
        help="The column of the dataset containing the edit instruction.",
    )
    parser.add_argument(
        "--val_image_url",
        type=str,
        default=None,
        help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
146
" `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="instruct-pix2pix-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images. All images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="Whether to randomly flip images horizontally.",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--conditioning_dropout_prob",
        type=float,
        default=None,
        help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
304
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
305
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def convert_to_np(image, resolution):
    image = image.convert("RGB").resize((resolution, resolution))
    return np.array(image).transpose(2, 0, 1)


def download_image(url):
    image = PIL.Image.open(requests.get(url, stream=True).raw)
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


def main():
    args = parse_args()

    if args.non_ema_revision is not None:
        deprecate(
            "non_ema_revision!=None",
            "0.15.0",
            message=(
                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                " use `--variant=non_ema` instead."
            ),
        )
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
    )

    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
    # initialized to zero.
    if accelerator.is_main_process:
        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
    in_channels = 8
    out_channels = unet.conv_in.out_channels
    unet.register_to_config(in_channels=in_channels)

    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
        unet.conv_in = new_conv_in
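        # Because the extra 4 input channels are zero-initialized, the conditioning image has no effect
        # at the very start of fine-tuning, so the modified UNet initially behaves like the pretrained
        # text-to-image checkpoint.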

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Create EMA for the unet.
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                ema_unet.load_state_dict(load_model.state_dict())
                ema_unet.to(accelerator.device)
                del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.original_image_column is None:
        original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        original_image_column = args.original_image_column
        if original_image_column not in column_names:
            raise ValueError(
                f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.edit_prompt_column is None:
        edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        edit_prompt_column = args.edit_prompt_column
        if edit_prompt_column not in column_names:
            raise ValueError(
                f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.edited_image_column is None:
        edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
    else:
        edited_image_column = args.edited_image_column
        if edited_image_column not in column_names:
            raise ValueError(
                f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(captions):
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        ]
    )

    def preprocess_images(examples):
        original_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
        )
        edited_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
        )
        # We need to ensure that the original and the edited images undergo the same
        # augmentation transforms.
        images = np.concatenate([original_images, edited_images])
        images = torch.tensor(images)
        images = 2 * (images / 255) - 1
        return train_transforms(images)

    def preprocess_train(examples):
        # Preprocess images.
        preprocessed_images = preprocess_images(examples)
        # Since the original and edited images were concatenated before
        # applying the transformations, we need to separate them and reshape
        # them accordingly.
        original_images, edited_images = preprocessed_images.chunk(2)
        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

        # Collate the preprocessed images into the `examples`.
        examples["original_pixel_values"] = original_images
        examples["edited_pixel_values"] = edited_images

        # Preprocess the captions.
        captions = [caption for caption in examples[edit_prompt_column]]
        examples["input_ids"] = tokenize_captions(captions)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
        edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {
            "original_pixel_values": original_pixel_values,
            "edited_pixel_values": edited_pixel_values,
            "input_ids": input_ids,
        }
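    # Each batch produced by `collate_fn` contains:
    #   - "original_pixel_values" / "edited_pixel_values": float tensors of shape
    #     (batch_size, 3, resolution, resolution) with values scaled to [-1, 1]
    #   - "input_ids": int64 tensor of shape (batch_size, tokenizer.model_max_length)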

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_unet.to(accelerator.device)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("instruct-pix2pix", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # We want to learn the denoising process w.r.t the edited images which
                # are conditioned on the original image (which was edited) and the edit instruction.
                # So, first, convert images to latent space.
                latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
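                # For the DDPM scheduler this computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.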

                # Get the text embedding for conditioning.
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Get the additional image embedding for conditioning.
                # Instead of getting a diagonal Gaussian here, we simply take the mode.
                original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()

                # Conditioning dropout to support classifier-free guidance during inference. For more details
                # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
                if args.conditioning_dropout_prob is not None:
                    random_p = torch.rand(bsz, device=latents.device, generator=generator)
                    # Sample masks for the edit prompts.
                    prompt_mask = random_p < 2 * args.conditioning_dropout_prob
                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)
                    # Final text conditioning.
                    null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

                    # Sample masks for the original images.
                    image_mask_dtype = original_image_embeds.dtype
                    image_mask = 1 - (
                        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
                        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
                    )
                    image_mask = image_mask.reshape(bsz, 1, 1, 1)
                    # Final image conditioning.
                    original_image_embeds = image_mask * original_image_embeds
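                    # With dropout probability p, only the text conditioning is dropped with probability p
                    # (random_p < p), only the image conditioning with probability p (2p <= random_p < 3p),
                    # and both with probability p (p <= random_p < 2p), mirroring Sec. 3.2.1 of the paper.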

                # Concatenate the `original_image_embeds` with the `noisy_latents`.
                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if (
                (args.val_image_url is not None)
                and (args.validation_prompt is not None)
                and (epoch % args.validation_epochs == 0)
            ):
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                if args.use_ema:
                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                    ema_unet.store(unet.parameters())
                    ema_unet.copy_to(unet.parameters())
                pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=unet,
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                original_image = download_image(args.val_image_url)
                edited_images = []
                with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
                    for _ in range(args.num_validation_images):
                        edited_images.append(
                            pipeline(
                                args.validation_prompt,
                                image=original_image,
                                num_inference_steps=20,
                                image_guidance_scale=1.5,
                                guidance_scale=7,
                                generator=generator,
                            ).images[0]
                        )

                for tracker in accelerator.trackers:
                    if tracker.name == "wandb":
                        wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                        for edited_image in edited_images:
                            wandb_table.add_data(
                                wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                            )
                        tracker.log({"validation": wandb_table})
                if args.use_ema:
                    # Switch back to the original UNet parameters.
                    ema_unet.restore(unet.parameters())

                del pipeline
                torch.cuda.empty_cache()

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=accelerator.unwrap_model(vae),
            unet=unet,
            revision=args.revision,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

        if args.validation_prompt is not None:
            edited_images = []
            pipeline = pipeline.to(accelerator.device)
            with torch.autocast(str(accelerator.device)):
                for _ in range(args.num_validation_images):
                    edited_images.append(
                        pipeline(
                            args.validation_prompt,
                            image=original_image,
                            num_inference_steps=20,
                            image_guidance_scale=1.5,
                            guidance_scale=7,
                            generator=generator,
                        ).images[0]
                    )

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                    for edited_image in edited_images:
                        wandb_table.add_data(
                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                        )
                    tracker.log({"test": wandb_table})

    accelerator.end_training()


if __name__ == "__main__":
    main()
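# Example usage (an illustrative sketch, not part of the original script; the model id, dataset and
# hyperparameter values below are placeholders to be replaced with your own):
#
#   accelerate launch train_instruct_pix2pix.py \
#       --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#       --dataset_name="fusing/instructpix2pix-1000-samples" \
#       --resolution=256 \
#       --train_batch_size=4 \
#       --gradient_accumulation_steps=4 \
#       --gradient_checkpointing \
#       --max_train_steps=15000 \
#       --learning_rate=5e-05 \
#       --conditioning_dropout_prob=0.05 \
#       --mixed_precision=fp16 \
#       --seed=42 \
#       --output_dir="instruct-pix2pix-model"
#
# The pipeline saved to `--output_dir` can then be loaded for inference, e.g.:
#
#   import torch
#   from diffusers import StableDiffusionInstructPix2PixPipeline
#
#   pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
#       "instruct-pix2pix-model", torch_dtype=torch.float16
#   ).to("cuda")
#   edited = pipe(
#       "turn the sky into a sunset",  # edit instruction (placeholder)
#       image=input_image,             # a PIL.Image to edit (assumed to exist)
#       num_inference_steps=20,
#       image_guidance_scale=1.5,
#       guidance_scale=7,
#   ).images[0]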