GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/dreambooth/train_dreambooth_flax.py
import argparse
import hashlib
import logging
import math
import os
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

logger = logging.getLogger(__name__)

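# Example invocation (illustrative only; the model id, prompts, and directories below are
# placeholders chosen for this sketch, not values shipped with this script):
#
#   python train_dreambooth_flax.py \
#     --pretrained_model_name_or_path="duongna/stable-diffusion-v1-4-flax" \
#     --instance_data_dir="./dog_images" \
#     --instance_prompt="a photo of sks dog" \
#     --class_data_dir="./class_images" \
#     --class_prompt="a photo of dog" \
#     --with_prior_preservation --prior_loss_weight=1.0 \
#     --resolution=512 --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=800 \
#     --output_dir="dreambooth-flax-model"
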
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with the identifier specifying the instance.",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the total train batch size (per-device batch size * number of devices).",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires hardware with"
            " native bfloat16 support (e.g. a TPU or an Nvidia Ampere GPU)."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")

    return args

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


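# Note on DreamBoothDataset indexing: len() is the larger of the instance and class image
# counts, and __getitem__ wraps each list with modulo indexing. With illustrative numbers,
# 5 instance images and 100 class images give a dataset of length 100 in which every
# instance image is reused 20 times while each class image appears once per pass.
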
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def get_params_to_save(params):
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))


def main():
    args = parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Set up logging; we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    rng = jax.random.PRNGKey(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
            ):
                prompt_ids = pipeline.prepare_inputs(example["prompt"])
                prompt_ids = shard(prompt_ids)
                p_params = jax_utils.replicate(params)
                rng = jax.random.split(rng)[0]
                sample_rng = jax.random.split(rng, jax.device_count())
                images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
                images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
                images = pipeline.numpy_to_pil(np.array(images))

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline

    # Handle the repository creation
    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
        )
    else:
        raise NotImplementedError("No tokenizer specified!")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        batch = {k: v.numpy() for k, v in batch.items()}
        return batch

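    # With --with_prior_preservation, collate_fn returns twice as many rows as dataset
    # examples: instance examples first, then the matching class examples, so the per-device
    # batch seen by train_step doubles. compute_loss relies on this ordering when it splits
    # model_pred and target in half with jnp.split.
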
    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    if len(train_dataset) < total_train_batch_size:
        raise ValueError(
            f"Training batch size is {total_train_batch_size}, but your dataset only contains"
            f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
            f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
        )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
    )

    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16

    if args.pretrained_vae_name_or_path:
        # TODO(patil-suraj): Upload flax weights for the VAE
        vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
    else:
        vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})

    # Load models and create wrapper for stable diffusion
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        vae_arg,
        dtype=weight_dtype,
        **vae_kwargs,
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
    )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(args.max_grad_norm),
        adamw,
    )

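    # optax.chain above applies its transformations in order: gradients are first clipped to
    # max_grad_norm globally, then handed to AdamW with the constant learning-rate schedule.
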
    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
    text_encoder_state = train_state.TrainState.create(
        apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
    )

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()

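    # The scheduler hyperparameters above (scaled_linear betas from 0.00085 to 0.012 over 1000
    # steps) mirror the training configuration of the Stable Diffusion v1 checkpoints, so the
    # fine-tuned UNet sees the same forward-diffusion process it was originally trained with
    # (assumption based on the public SD v1 configs, not stated in this script).
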
    # Initialize our training
    train_rngs = jax.random.split(rng, jax.local_device_count())

    def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        if args.train_text_encoder:
            params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
        else:
            params = {"unet": unet_state.params}

        def compute_loss(params):
            # Convert images to latent space
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape)
            # Sample a random timestep for each image
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

            # Get the text embedding for conditioning
            if args.train_text_encoder:
                encoder_hidden_states = text_encoder_state.apply_fn(
                    batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
                )[0]
            else:
                encoder_hidden_states = text_encoder(
                    batch["input_ids"], params=text_encoder_state.params, train=False
                )[0]

            # Predict the noise residual
            model_pred = unet.apply(
                {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
                target, target_prior = jnp.split(target, 2, axis=0)

                # Compute instance loss
                loss = (target - model_pred) ** 2
                loss = loss.mean()

                # Compute prior loss
                prior_loss = (target_prior - model_pred_prior) ** 2
                prior_loss = prior_loss.mean()

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = (target - model_pred) ** 2
                loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(params)
        grad = jax.lax.pmean(grad, "batch")

        new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
        if args.train_text_encoder:
            new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
        else:
            new_text_encoder_state = text_encoder_state

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_unet_state, new_text_encoder_state, metrics, new_train_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))

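    # p_train_step above runs one replica of train_step per local device. shard() (used in the
    # training loop below) reshapes each host-level batch so its leading dimension becomes
    # [local_device_count, per-device batch, ...], jax.lax.pmean averages gradients and metrics
    # across the "batch" axis, and donate_argnums=(0, 1) lets XLA reuse the buffers of the
    # incoming UNet and text-encoder states.
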
    # Replicate the train state on each device
    unet_state = jax_utils.replicate(unet_state)
    text_encoder_state = jax_utils.replicate(text_encoder_state)
    vae_params = jax_utils.replicate(vae_params)

    # Train!
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    def checkpoint(step=None):
        # Create the pipeline using the trained modules and save it.
        scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
        pipeline.save_pretrained(
            outdir,
            params={
                "text_encoder": get_params_to_save(text_encoder_state.params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(unet_state.params),
                "safety_checker": safety_checker.params,
            },
        )

        if args.push_to_hub:
            message = f"checkpoint-{step}" if step is not None else "End of training"
            repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================

        train_metrics = []

        steps_per_epoch = len(train_dataset) // total_train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
                unet_state, text_encoder_state, vae_params, batch, train_rngs
            )
            train_metrics.append(train_metric)

            train_step_progress_bar.update(jax.local_device_count())

            global_step += 1
            if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
                checkpoint(global_step)
            if global_step >= args.max_train_steps:
                break

        train_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

    if jax.process_index() == 0:
        checkpoint()


if __name__ == "__main__":
    main()