Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ShivamShrirao
GitHub Repository: ShivamShrirao/diffusers
Path: blob/main/examples/research_projects/colossalai/train_dreambooth_colossalai.py
1979 views
1
import argparse
2
import hashlib
3
import math
4
import os
5
from pathlib import Path
6
from typing import Optional
7
8
import colossalai
9
import torch
10
import torch.nn.functional as F
11
import torch.utils.checkpoint
12
from colossalai.context.parallel_mode import ParallelMode
13
from colossalai.core import global_context as gpc
14
from colossalai.logging import disable_existing_loggers, get_dist_logger
15
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
16
from colossalai.nn.parallel.utils import get_static_torch_model
17
from colossalai.utils import get_current_device
18
from colossalai.utils.model.colo_init_context import ColoInitContext
19
from huggingface_hub import HfFolder, Repository, create_repo, whoami
20
from PIL import Image
21
from torch.utils.data import Dataset
22
from torchvision import transforms
23
from tqdm.auto import tqdm
24
from transformers import AutoTokenizer, PretrainedConfig
25
26
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
27
from diffusers.optimization import get_scheduler
28
29
30
# Module-level logging setup. `logger` is the shared logger used by
# parse_args() and main() below.
disable_existing_loggers()
logger = get_dist_logger()
32
33
34
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: Optional[str] = None):
    """Return the text-encoder class named in a checkpoint's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Local path or Hub identifier of the
            pipeline checkpoint.
        revision: Revision forwarded to ``PretrainedConfig.from_pretrained``.
            When omitted, the ``revision`` attribute of a module-level ``args``
            object is used if such an object exists (this is what the function
            always did before the parameter was added); otherwise ``None`` is
            forwarded.

    Returns:
        The ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation`` class.

    Raises:
        ValueError: If the config lists no architecture, or lists one that is
            not supported here.
    """
    if revision is None:
        # Backward compatibility: fall back to the script-level `args` global,
        # but do not crash with NameError when the module is merely imported.
        revision = getattr(globals().get("args"), "revision", None)

    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    architectures = text_encoder_config.architectures
    if not architectures:
        raise ValueError(
            f"The text_encoder config of {pretrained_model_name_or_path} does not list any architecture."
        )
    model_class = architectures[0]

    # The imports are kept local so only the encoder that is actually needed gets loaded.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")
52
53
54
def parse_args(input_args=None):
    """Define the command-line interface, parse it, and validate the result.

    Args:
        input_args: Optional list of argument strings to parse. When ``None``,
            argparse reads the process command line instead.

    Returns:
        The parsed ``argparse.Namespace``. Its ``local_rank`` is overwritten by
        the ``LOCAL_RANK`` environment variable when that variable is set to a
        value other than -1 that differs from the parsed one.

    Raises:
        ValueError: If ``--with_prior_preservation`` is given without
            ``--class_data_dir`` or without ``--class_prompt``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    add = parser.add_argument

    # --- model / tokenizer -------------------------------------------------
    add(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    add(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    add(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )

    # --- data and prompts --------------------------------------------------
    add(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    add(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    add(
        "--instance_prompt",
        type=str,
        default="a photo of sks dog",
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    add(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    add(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    add("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    add(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )

    # --- output, seed, image preprocessing ---------------------------------
    add(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    add("--seed", type=int, default=None, help="A seed for reproducible training.")
    add(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    add(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    add(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )

    # --- batch sizes and schedule length -----------------------------------
    add("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
    add("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
    add("--num_train_epochs", type=int, default=1)
    add(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    add("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    add(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )

    # --- optimisation ------------------------------------------------------
    add(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    add(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    add(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    add("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
    add("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
    add("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    # --- hub, logging, precision, distributed ------------------------------
    add("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    add("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    add(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    add(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    add(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    add("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    # argparse falls back to the process command line when handed None.
    args = parser.parse_args(input_args)

    # A LOCAL_RANK set in the environment wins over the CLI value.
    rank_from_env = int(os.environ.get("LOCAL_RANK", -1))
    if rank_from_env not in (-1, args.local_rank):
        args.local_rank = rank_from_env

    if not args.with_prior_preservation:
        # Class options are harmless but pointless without prior preservation.
        if args.class_data_dir is not None:
            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            logger.warning("You need not use --class_prompt without --with_prior_preservation.")
        return args

    if args.class_data_dir is None:
        raise ValueError("You must specify a data directory for class images.")
    if args.class_prompt is None:
        raise ValueError("You must specify prompt for class images.")
    return args
249
250
251
class DreamBoothDataset(Dataset):
    """
    Dataset pairing instance images (and, optionally, class images) with their tokenized prompts.
    Each item is a dict with the transformed instance image and the unpadded token ids of the
    instance prompt; when a class data root is given it also carries the class image and prompt ids.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        # Every directory entry is treated as an image file.
        self.instance_images_path = list(self.instance_data_root.iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is None:
            self.class_data_root = None
        else:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            # The longer of the two image lists sets the length; the shorter
            # one is cycled via modulo indexing in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

        crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                crop,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _load_rgb(self, image_path):
        # Open an image and make sure it is in RGB mode.
        picture = Image.open(image_path)
        return picture if picture.mode == "RGB" else picture.convert("RGB")

    def _tokenize(self, prompt):
        # Token ids are left unpadded here; padding happens when batches are collated.
        encoded = self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        )
        return encoded.input_ids

    def __getitem__(self, index):
        instance_path = self.instance_images_path[index % self.num_instance_images]
        example = {"instance_images": self.image_transforms(self._load_rgb(instance_path))}
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            class_path = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._load_rgb(class_path))
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example
328
329
330
class PromptDataset(Dataset):
    """Dataset that yields one fixed prompt `num_samples` times, used to drive class-image generation."""

    def __init__(self, prompt, num_samples):
        self.num_samples = num_samples
        self.prompt = prompt

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # The index travels with the prompt so generated files can be numbered.
        return {"prompt": self.prompt, "index": index}
345
346
347
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return ``"<owner>/<model_id>"`` for a Hub repository.

    Args:
        model_id: Repository name without the owner prefix.
        organization: Owner to use. When ``None`` the owner is the ``"name"``
            field returned by ``whoami`` for the token.
        token: Hub token. When ``None`` it is fetched with ``HfFolder.get_token()``.

    Returns:
        The owner-qualified repository name.
    """
    if token is None:
        token = HfFolder.get_token()
    owner = organization if organization is not None else whoami(token)["name"]
    return f"{owner}/{model_id}"
355
356
357
# Gemini + ZeRO DDP
358
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
    """Wrap ``model`` in ColossalAI's ``GeminiDDP`` and return the wrapper.

    Args:
        model: Module to wrap.
        placememt_policy: Value forwarded as GeminiDDP's ``placement_policy``
            (the misspelled parameter name is part of the existing interface).

    Returns:
        The ``GeminiDDP`` instance built on the current device with pinned
        memory and ``search_range_mb=64``.
    """
    from colossalai.nn.parallel import GeminiDDP

    wrapper_options = {
        "device": get_current_device(),
        "placement_policy": placememt_policy,
        "pin_memory": True,
        "search_range_mb": 64,
    }
    return GeminiDDP(model, **wrapper_options)
365
366
367
def main(args):
    """Run DreamBooth fine-tuning of the UNet with ColossalAI Gemini/ZeRO.

    In order: launch the distributed context, optionally generate the missing
    class images for prior preservation, prepare the output directory or Hub
    repository on rank 0, load the tokenizer and models, build the dataloader,
    optimizer and schedulers, run the training loop with periodic checkpoints,
    and finally save (and optionally push) the resulting pipeline.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If the noise scheduler's prediction type is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Only forward a seed to the launcher when one was requested.
    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    # ------------------------------------------------------------------
    # Generate class images when prior preservation needs more of them.
    # ------------------------------------------------------------------
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        # exist_ok avoids a crash when several ranks create the directory at once.
        class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            # Compare the device *type*: a torch.device never equals the plain
            # string "cuda", so the previous check always selected float32.
            torch_dtype = torch.float16 if get_current_device().type == "cuda" else torch.float32
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            pipeline.to(get_current_device())

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=local_rank != 0,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    # The content hash keeps file names unique across runs.
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline

    # ------------------------------------------------------------------
    # Handle the repository creation (rank 0 only).
    # ------------------------------------------------------------------
    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Append the ignore patterns only when they are missing. Opening
            # with "w+" truncated the file first, which discarded any rules
            # already present and made the membership test meaningless.
            gitignore_path = Path(args.output_dir) / ".gitignore"
            previous = gitignore_path.read_text() if gitignore_path.exists() else ""
            known_patterns = set(previous.splitlines())
            missing = [pattern for pattern in ("step_*", "epoch_*") if pattern not in known_patterns]
            if missing:
                with gitignore_path.open("a") as gitignore:
                    if previous and not previous.endswith("\n"):
                        gitignore.write("\n")
                    gitignore.writelines(f"{pattern}\n" for pattern in missing)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Load the tokenizer and the models.
    # ------------------------------------------------------------------
    if args.tokenizer_name:
        logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
    elif args.pretrained_model_name_or_path:
        logger.info("Loading tokenizer from pretrained model", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)

    logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):
        unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
        )

    # Only the UNet is trained; the VAE and text encoder stay frozen.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(
        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
    )

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # ------------------------------------------------------------------
    # Dataset and dataloader.
    # ------------------------------------------------------------------
    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        # Build one batch: stacked image tensor plus prompt ids padded to the model's max length.
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
    )

    # ------------------------------------------------------------------
    # Scheduler and math around the number of training steps.
    # ------------------------------------------------------------------
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # ------------------------------------------------------------------
    # Train!
    # ------------------------------------------------------------------
    total_batch_size = args.train_batch_size * world_size

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=local_rank != 0)
    progress_bar.set_description("Steps")
    global_step = 0

    torch.cuda.synchronize()
    for _epoch in range(args.num_train_epochs):
        unet.train()
        for batch in train_dataloader:
            torch.cuda.reset_peak_memory_stats()
            # Move batch to gpu
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            optimizer.zero_grad()

            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            # NOTE(review): 0.18215 is hard-coded; presumably the latent scaling
            # factor of the Stable Diffusion VAE — confirm for other checkpoints.
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # The backward pass goes through the Gemini optimizer rather than loss.backward().
            optimizer.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])

            # One optimizer step was taken for this batch; update the bookkeeping.
            progress_bar.update(1)
            global_step += 1
            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
            progress_bar.set_postfix(**logs)

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                # Called on every rank; only rank 0 writes the checkpoint to disk.
                torch_unet = get_static_torch_model(unet)
                if local_rank == 0:
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=torch_unet,
                        revision=args.revision,
                    )
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    pipeline.save_pretrained(save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break

    # ------------------------------------------------------------------
    # Save the final pipeline.
    # ------------------------------------------------------------------
    torch.cuda.synchronize()
    unet = get_static_torch_model(unet)

    if local_rank == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=unet,
            revision=args.revision,
        )

        pipeline.save_pretrained(args.output_dir)
        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
683
684
685
if __name__ == "__main__":
    # `args` is intentionally bound at module scope: code elsewhere in this
    # file reads `args.revision` from the module globals.
    args = parse_args()
    main(args)
688
689