GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/unconditional_image_generation/train_unconditional.py
import argparse
import inspect
import logging
import math
import os
from pathlib import Path
from typing import Optional

import accelerate
import datasets
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")

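# A typical invocation of this script (illustrative only; the dataset name and
# hyperparameter values below are example choices, not requirements -- any image
# dataset that HF Datasets can load works):
#
#   accelerate launch train_unconditional.py \
#       --dataset_name="huggan/flowers-102-categories" \
#       --resolution=64 \
#       --output_dir="ddpm-ema-flowers-64" \
#       --train_batch_size=16 \
#       --num_epochs=100 \
#       --learning_rate=1e-4 \
#       --lr_warmup_steps=500 \
#       --use_ema
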
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
        dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    if not isinstance(arr, torch.Tensor):
        arr = torch.from_numpy(arr)
    res = arr[timesteps].float().to(timesteps.device)
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)

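# Example (illustrative): with `arr = noise_scheduler.alphas_cumprod` of shape (1000,),
# `timesteps` of shape (B,), and `broadcast_shape = (B, 1, 1, 1)`, the result broadcasts
# cleanly against a batch of (B, C, H, W) images -- exactly how it is used in the
# "sample" prediction branch of the training loop below.
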
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--model_config_name_or_path",
        type=str,
        default=None,
        help="The config of the UNet model to train, leave as None to use the standard DDPM configuration.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder."
            " Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-64",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        default=False,
        action="store_true",
        help="Whether to randomly flip images horizontally.",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
    parser.add_argument(
        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use Exponential Moving Average for the final model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints."
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more details."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    return args

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

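# For example, get_full_repo_name("ddpm-model-64") resolves to "<username>/ddpm-model-64"
# for the currently logged-in Hub user, while passing organization="my-org" (a hypothetical
# org name) yields "my-org/ddpm-model-64".
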
def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.logger,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.logger == "tensorboard":
        if not is_tensorboard_available():
            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")

    elif args.logger == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
                ema_model.load_state_dict(load_model.state_dict())
                ema_model.to(accelerator.device)
                del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

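    # With these hooks in place, `accelerator.save_state(...)` writes checkpoints laid out as
    # checkpoint-<global_step>/unet (plus checkpoint-<global_step>/unet_ema when --use_ema is
    # set), alongside accelerate's own optimizer/scheduler/RNG state files.
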
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Initialize the model
    if args.model_config_name_or_path is None:
        model = UNet2DModel(
            sample_size=args.resolution,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 128, 256, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
    else:
        config = UNet2DModel.load_config(args.model_config_name_or_path)
        model = UNet2DModel.from_config(config)

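    # Note on the default config above: self-attention is inserted only at the fifth of the
    # six down blocks and at the mirrored (second) up block, i.e. at a heavily downsampled
    # resolution; all other blocks are plain ResNet blocks, matching the standard DDPM setup.
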
    # Create EMA for the model.
    if args.use_ema:
        ema_model = EMAModel(
            model.parameters(),
            decay=args.ema_max_decay,
            use_ema_warmup=True,
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            model_cls=UNet2DModel,
            model_config=model.config,
        )

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            model.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly.")

    # Initialize the scheduler
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets and DataLoaders creation.
    augmentations = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

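    # transforms.Normalize([0.5], [0.5]) maps pixel values from [0, 1] (after ToTensor) to
    # [-1, 1], the range the diffusion model trains on; samples are mapped back to [0, 1]
    # at inference time.
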
    def transform_images(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    logger.info(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transform_images)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs),
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_model.to(accelerator.device)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

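    # Worked example (illustrative numbers): with 1000 batches per epoch and
    # gradient_accumulation_steps=2, num_update_steps_per_epoch is 500; resuming from
    # checkpoint-1250 gives first_epoch = 1250 // 500 = 2 and
    # resume_step = (1250 * 2) % (500 * 2) = 500, i.e. training restarts halfway through epoch 2.
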
    # Train!
    for epoch in range(first_epoch, args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

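            # add_noise implements the closed-form DDPM forward process:
            #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
            # where alpha_bar_t = noise_scheduler.alphas_cumprod[t] (Ho et al., 2020).
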
            with accelerator.accumulate(model):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps).sample

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
                    snr_weights = alpha_t / (1 - alpha_t)
                    loss = snr_weights * F.mse_loss(
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

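                # Under accelerator.accumulate(model), gradients are synchronized across
                # processes only on the micro-batch that completes an accumulation window,
                # which is why the clipping/EMA/bookkeeping below is gated on
                # accelerator.sync_gradients.
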
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_model.step(model.parameters())
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.cur_decay_value
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                unet = accelerator.unwrap_model(model)

                if args.use_ema:
                    ema_model.store(unet.parameters())
                    ema_model.copy_to(unet.parameters())

                pipeline = DDPMPipeline(
                    unet=unet,
                    scheduler=noise_scheduler,
                )

                generator = torch.Generator(device=pipeline.device).manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(
                    generator=generator,
                    batch_size=args.eval_batch_size,
                    num_inference_steps=args.ddpm_num_inference_steps,
                    output_type="numpy",
                ).images

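                # With output_type="numpy", the pipeline returns images as float arrays of
                # shape (batch, H, W, 3) scaled to [0, 1], hence the * 255 conversion below.
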
                if args.use_ema:
                    ema_model.restore(unet.parameters())

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")

                if args.logger == "tensorboard":
                    if is_accelerate_version(">=", "0.17.0.dev0"):
                        tracker = accelerator.get_tracker("tensorboard", unwrap=True)
                    else:
                        tracker = accelerator.get_tracker("tensorboard")
                    tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
                elif args.logger == "wandb":
                    # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
                    accelerator.get_tracker("wandb").log(
                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
                        step=global_step,
                    )

            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model
                unet = accelerator.unwrap_model(model)

                if args.use_ema:
                    ema_model.store(unet.parameters())
                    ema_model.copy_to(unet.parameters())

                pipeline = DDPMPipeline(
                    unet=unet,
                    scheduler=noise_scheduler,
                )

                pipeline.save_pretrained(args.output_dir)

                if args.use_ema:
                    ema_model.restore(unet.parameters())

                if args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)