Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ShivamShrirao
GitHub Repository: ShivamShrirao/diffusers
Path: blob/main/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
1980 views
1
import argparse
2
import inspect
3
import logging
4
import math
5
import os
6
from pathlib import Path
7
from typing import Optional
8
9
import datasets
10
import torch
11
import torch.nn.functional as F
12
from accelerate import Accelerator
13
from accelerate.logging import get_logger
14
from accelerate.utils import ProjectConfiguration
15
from datasets import load_dataset
16
from huggingface_hub import HfFolder, Repository, create_repo, whoami
17
from onnxruntime.training.ortmodule import ORTModule
18
from torchvision import transforms
19
from tqdm.auto import tqdm
20
21
import diffusers
22
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
23
from diffusers.optimization import get_scheduler
24
from diffusers.training_utils import EMAModel
25
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
26
27
28
# Abort with an explanatory error when the installed diffusers is older than the
# version this example targets. Remove at your own risk.
_MIN_DIFFUSERS_VERSION = "0.13.0.dev0"
check_min_version(_MIN_DIFFUSERS_VERSION)

# Script-wide accelerate logger; calls may pass `main_process_only=` per message.
logger = get_logger(__name__, log_level="INFO")
32
33
34
def _extract_into_tensor(arr, timesteps, broadcast_shape):
35
"""
36
Extract values from a 1-D numpy array for a batch of indices.
37
:param arr: the 1-D numpy array.
38
:param timesteps: a tensor of indices into the array to extract.
39
:param broadcast_shape: a larger shape of K dimensions with the batch
40
dimension equal to the length of timesteps.
41
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
42
"""
43
if not isinstance(arr, torch.Tensor):
44
arr = torch.from_numpy(arr)
45
res = arr[timesteps].float().to(timesteps.device)
46
while len(res.shape) < len(broadcast_shape):
47
res = res[..., None]
48
return res.expand(broadcast_shape)
49
50
51
def parse_args(input_args=None):
    """Parse command-line options for unconditional DDPM training.

    Args:
        input_args: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic launches).
            ``None`` (the default) keeps the usual command-line behaviour.

    Returns:
        An ``argparse.Namespace`` with all training options. ``local_rank`` is
        overridden by the ``LOCAL_RANK`` environment variable when that is set
        to something other than -1.

    Raises:
        ValueError: if neither ``--dataset_name`` nor ``--train_data_dir`` is
            given, or if one of the step/epoch interval options is not positive.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training images. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. Ignored if `dataset_name` is"
            " specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-64",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        default=False,
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
    parser.add_argument(
        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use Exponential Moving Average for the final model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints"
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, resolved relative to"
            " `--output_dir`."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    args = parser.parse_args(input_args)

    # Distributed launchers communicate the rank through the environment.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    # These options are used as divisors / modulo operands in the training loop;
    # reject bad values up front instead of failing with ZeroDivisionError mid-run.
    for option in ("gradient_accumulation_steps", "save_images_epochs", "save_model_epochs", "checkpointing_steps"):
        value = getattr(args, option)
        if value <= 0:
            raise ValueError(f"--{option} must be a positive integer, got {value}.")

    return args
264
265
266
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the fully-qualified Hub repository id ``"<namespace>/<model_id>"``.

    When ``token`` is omitted, the token stored by ``HfFolder`` is looked up.
    The namespace is ``organization`` if provided; otherwise it is the account
    name that ``whoami`` reports for the resolved token.
    """
    if token is None:
        token = HfFolder.get_token()
    namespace = organization if organization is not None else whoami(token)["name"]
    return f"{namespace}/{model_id}"
274
275
276
def _ensure_gitignore_patterns(output_dir, patterns):
    """Make sure every entry of ``patterns`` is listed in ``<output_dir>/.gitignore``.

    Lines already present in the file are kept; only missing patterns are
    appended, one per line. The file is created if it does not exist.

    Args:
        output_dir: directory holding the ``.gitignore``.
        patterns: gitignore patterns, compared against the file's lines after stripping whitespace.
    """
    gitignore_path = os.path.join(output_dir, ".gitignore")
    content = ""
    if os.path.isfile(gitignore_path):
        with open(gitignore_path, encoding="utf-8") as gitignore:
            content = gitignore.read()

    listed = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in listed]
    if not missing:
        return

    # Start on a fresh line when the existing file lacks a trailing newline.
    separator = "\n" if content and not content.endswith("\n") else ""
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        gitignore.write(separator + "".join(f"{pattern}\n" for pattern in missing))


def _latest_checkpoint_name(output_dir):
    """Return the ``checkpoint-<step>`` entry of ``output_dir`` with the largest integer step.

    Returns ``None`` when ``output_dir`` does not exist or contains no entry of
    that form; entries whose suffix is not an integer are ignored.
    """
    if not os.path.isdir(output_dir):
        return None
    steps_by_name = {}
    for name in os.listdir(output_dir):
        prefix, sep, step = name.partition("-")
        if prefix == "checkpoint" and sep and step.isdigit():
            steps_by_name[name] = int(step)
    if not steps_by_name:
        return None
    return max(steps_by_name, key=steps_by_name.get)


def main(args):
    """Train an unconditional ``UNet2DModel`` with DDPM noise, wrapped in ONNX Runtime's ``ORTModule``.

    Sets up Accelerate (mixed precision, gradient accumulation, trackers), the
    optional Hub repository, model / EMA / noise scheduler / optimizer, the image
    dataset and LR schedule, then trains with periodic checkpointing, sample
    image logging and pipeline saving.

    Args:
        args: the ``argparse.Namespace`` returned by ``parse_args``.
    """
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.logger,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.logger == "tensorboard":
        if not is_tensorboard_available():
            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")

    elif args.logger == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token, private=args.hub_private_repo)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Append (never truncate) so entries from the cloned repo survive.
            # `checkpoint-*` matches the accelerator state saved below, which is
            # only meant for resuming and should not be pushed with the model.
            _ensure_gitignore_patterns(args.output_dir, ["step_*", "epoch_*", "checkpoint-*"])
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Initialize the model
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    # Create EMA for the model.
    if args.use_ema:
        ema_model = EMAModel(
            model.parameters(),
            decay=args.ema_max_decay,
            use_ema_warmup=True,
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
        )

    # Initialize the scheduler; older DDPMScheduler versions lack `prediction_type`.
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets and DataLoaders creation.
    augmentations = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def transform_images(examples):
        # Only the "image" column is used; every image is forced to RGB before augmentation.
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    logger.info(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transform_images)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs),
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    model = ORTModule(model)

    if args.use_ema:
        accelerator.register_for_checkpointing(ema_model)
        ema_model.to(accelerator.device)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # Only the last path component is used; the state is loaded from `output_dir`.
            # `normpath` drops a trailing separator that would otherwise yield an empty name.
            path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
        else:
            # Get the most recent checkpoint
            path = _latest_checkpoint_name(args.output_dir)

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            # Convert optimizer steps back into dataloader steps to know how far into
            # the interrupted epoch to fast-forward.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Train!
    for epoch in range(first_epoch, args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps, return_dict=False)[0]

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
                    snr_weights = alpha_t / (1 - alpha_t)
                    loss = snr_weights * F.mse_loss(
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_model.step(model.parameters())
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection and/or save the pipeline.
        if accelerator.is_main_process:
            is_last_epoch = epoch == args.num_epochs - 1
            save_images = epoch % args.save_images_epochs == 0 or is_last_epoch
            save_model = epoch % args.save_model_epochs == 0 or is_last_epoch

            if save_images or save_model:
                unet = accelerator.unwrap_model(model)
                if args.use_ema:
                    # Sample/save with the EMA weights, but first keep a copy of the live
                    # training weights: `copy_to` overwrites the same parameters the
                    # optimizer updates, and training must resume from them afterwards.
                    training_params = [param.detach().clone() for param in unet.parameters()]
                    ema_model.copy_to(unet.parameters())

                # Built whenever either artifact is needed, so saving the model never
                # depends on an image-sampling epoch having happened earlier.
                pipeline = DDPMPipeline(
                    unet=unet,
                    scheduler=noise_scheduler,
                )

                if save_images:
                    generator = torch.Generator(device=pipeline.device).manual_seed(0)
                    # run pipeline in inference (sample random noise and denoise)
                    images = pipeline(
                        generator=generator,
                        batch_size=args.eval_batch_size,
                        output_type="numpy",
                        num_inference_steps=args.ddpm_num_inference_steps,
                    ).images

                    # Scale to 0-255 uint8 for the trackers (assumes pipeline output lies in [0, 1]).
                    images_processed = (images * 255).round().astype("uint8")

                    if args.logger == "tensorboard":
                        accelerator.get_tracker("tensorboard").add_images(
                            "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                        )
                    elif args.logger == "wandb":
                        accelerator.get_tracker("wandb").log(
                            {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
                            step=global_step,
                        )

                if save_model:
                    # save the model
                    pipeline.save_pretrained(args.output_dir)
                    if args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)

                if args.use_ema:
                    # Put the training weights back so optimization continues from them.
                    with torch.no_grad():
                        for param, saved in zip(unet.parameters(), training_params):
                            param.copy_(saved)
                    del training_params

    accelerator.end_training()
602
603
604
if __name__ == "__main__":
    # Parse the CLI and hand the resulting namespace straight to the trainer.
    main(parse_args())
607
608