GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/textual_inversion/textual_inversion_flax.py
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import PIL
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = logging.getLogger(__name__)
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as the initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save the learned embeddings every X update steps.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help=(
            "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
            " private models)."
        ),
    )
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
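# Dataset that pairs each training image with a randomly chosen prompt template
# containing the placeholder token (object or style templates, depending on
# `learnable_property`).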
class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"
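# Grow the text encoder's token embedding matrix to `new_num_tokens` rows, keep the
# existing rows, and initialize the placeholder token's row from the initializer
# token's embedding.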
def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
    if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
        return model
    model.config.vocab_size = new_num_tokens

    params = model.params
    old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
    old_num_tokens, emb_dim = old_embeddings.shape

    initializer = jax.nn.initializers.normal()

    new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
    new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
    new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
    params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings

    model.params = params
    return model
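# Take the first replica of each parameter (dropping the device axis added by
# `jax_utils.replicate`) and move it to host memory for saving.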
def get_params_to_save(params):
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
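# End-to-end training flow: set up the output dir / Hub repo, extend the tokenizer and
# text encoder with the placeholder token, build the dataset and dataloader, train only
# the new token embedding, and save the resulting pipeline and learned embeddings.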
def main():
    args = parse_args()

    if args.seed is not None:
        set_seed(args.seed)

    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging: we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Add the placeholder token to the tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token and placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Create sampling rng
    rng = jax.random.PRNGKey(args.seed)
    rng, _ = jax.random.split(rng)
    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder = resize_token_embeddings(
        text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
    )
    original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        input_ids = torch.stack([example["input_ids"] for example in examples])

        batch = {"pixel_values": pixel_values, "input_ids": input_ids}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    optimizer = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )
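    # Build a label tree for optax.multi_transform: the token embedding table is labeled
    # "token_embedding" (and gets AdamW updates); every other parameter is labeled "zero"
    # and receives zeroed gradients, keeping the rest of the text encoder frozen.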
    def create_mask(params, label_fn):
        def _map(params, mask, label_fn):
            for k in params:
                if label_fn(k):
                    mask[k] = "token_embedding"
                else:
                    if isinstance(params[k], dict):
                        mask[k] = {}
                        _map(params[k], mask[k], label_fn)
                    else:
                        mask[k] = "zero"

        mask = {}
        _map(params, mask, label_fn)
        return mask

    def zero_grads():
        # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
        def init_fn(_):
            return ()

        def update_fn(updates, state, params=None):
            return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

        return optax.GradientTransformation(init_fn, update_fn)

    # Zero out gradients of layers other than the token embedding layer
    tx = optax.multi_transform(
        {"token_embedding": optimizer, "zero": zero_grads()},
        create_mask(text_encoder.params, lambda s: s == "token_embedding"),
    )

    state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()

    # Initialize our training
    train_rngs = jax.random.split(rng, jax.local_device_count())

    # Define gradient train step fn
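    # One training step: encode images to latents with the frozen VAE, add noise at
    # random timesteps, predict the noise with the frozen UNet conditioned on the text
    # encoder output, and take the MSE loss. The optimizer mask zeroes gradients
    # everywhere except the token embedding table, and after the update every row of
    # that table except the placeholder's is reset to its original value.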
    def train_step(state, vae_params, unet_params, batch, train_rng):
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        def compute_loss(params):
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape)
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
            encoder_hidden_states = state.apply_fn(
                batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
            )[0]
            # Predict the noise residual and compute loss
            model_pred = unet.apply(
                {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = (target - model_pred) ** 2
            loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        # Keep the token embeddings fixed except the newly added embeddings for the concept,
        # as we only want to optimize the concept embeddings
        token_embeds = original_token_embeds.at[placeholder_token_id].set(
            new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
        )
        new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_train_rng

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)
    vae_params = jax_utils.replicate(vae_params)
    unet_params = jax_utils.replicate(unet_params)

    # Train!
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================

        train_metrics = []

        steps_per_epoch = len(train_dataset) // total_train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)
            global_step += 1

            if global_step >= args.max_train_steps:
                break
            if global_step % args.save_steps == 0:
                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
                    "embedding"
                ][placeholder_token_id]
                learned_embeds_dict = {args.placeholder_token: learned_embeds}
                jnp.save(
                    os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
                )

        train_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

    # Create the pipeline using the trained modules and save it.
    if jax.process_index() == 0:
        scheduler = FlaxPNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        )
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        pipeline.save_pretrained(
            args.output_dir,
            params={
                "text_encoder": get_params_to_save(state.params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(unet_params),
                "safety_checker": safety_checker.params,
            },
        )

        # Also save the newly trained embeddings
        learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
            placeholder_token_id
        ]
        learned_embeds_dict = {args.placeholder_token: learned_embeds}
        jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    main()