Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/text_to_image/train_text_to_image_flax.py
1448 views
1
import argparse
2
import logging
3
import math
4
import os
5
import random
6
from pathlib import Path
7
from typing import Optional
8
9
import jax
10
import jax.numpy as jnp
11
import numpy as np
12
import optax
13
import torch
14
import torch.utils.checkpoint
15
import transformers
16
from datasets import load_dataset
17
from flax import jax_utils
18
from flax.training import train_state
19
from flax.training.common_utils import shard
20
from huggingface_hub import HfFolder, Repository, create_repo, whoami
21
from torchvision import transforms
22
from tqdm.auto import tqdm
23
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
24
25
from diffusers import (
26
FlaxAutoencoderKL,
27
FlaxDDPMScheduler,
28
FlaxPNDMScheduler,
29
FlaxStableDiffusionPipeline,
30
FlaxUNet2DConditionModel,
31
)
32
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
33
from diffusers.utils import check_min_version
34
35
36
# Abort at import time when the installed diffusers release is older than the
# one this script was written against. Remove at your own risks.
check_min_version("0.15.0.dev0")

# Module-level logger; main() lowers it to ERROR on every process except 0.
logger = logging.getLogger(__name__)
40
41
42
def parse_args():
    """Parse and validate the command-line arguments of this training script.

    When the ``LOCAL_RANK`` environment variable is set (to anything other than
    -1) it overrides ``--local_rank``.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if neither ``--dataset_name`` nor ``--train_data_dir`` is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        # The training loop multiplies the LR by train_batch_size * local device count;
        # there is no gradient accumulation in this script.
        help="Scale the learning rate by the total train batch size (per-device batch size x number of local devices).",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]. Note: this script currently always uses a constant schedule.'
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        # Adjacent string literals are concatenated verbatim, so each piece must
        # carry its own separating space (previously rendered "Choosebetween").
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args
223
224
225
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build a ``"<namespace>/<model_id>"`` Hub repository id.

    The namespace is ``organization`` when given. Otherwise it is the user name
    returned by ``whoami`` for ``token``. When ``token`` is None, the locally
    stored token (``HfFolder.get_token()``) is looked up first.
    """
    resolved_token = HfFolder.get_token() if token is None else token
    if organization is not None:
        return f"{organization}/{model_id}"
    owner = whoami(resolved_token)["name"]
    return f"{owner}/{model_id}"
233
234
235
# Known Hub datasets -> (image column, caption column). main() consults this only
# when --image_column / --caption_column is None.
dataset_name_mapping = dict(
    [
        ("lambdalabs/pokemon-blip-captions", ("image", "text")),
    ]
)
238
239
240
def get_params_to_save(params):
    """Return a host-side copy of the first replica of a replicated pytree.

    Every leaf is indexed with ``[0]`` along its leading (device) axis, and the
    resulting tree is fetched with ``jax.device_get``.
    """
    first_replica = jax.tree_util.tree_map(lambda leaf: leaf[0], params)
    return jax.device_get(first_replica)
242
243
244
def _ensure_gitignore_entries(directory, patterns):
    """Guarantee that every pattern in ``patterns`` is listed in ``directory/.gitignore``.

    Existing content is preserved and only missing patterns are appended, one per
    line. Opening the file in ``"w+"`` mode would instead truncate any
    ``.gitignore`` that came with the cloned repository.

    Args:
        directory: Folder that holds (or will hold) the ``.gitignore`` file.
        patterns: Ignore patterns to guarantee, e.g. ``["step_*", "epoch_*"]``.
    """
    gitignore_path = os.path.join(directory, ".gitignore")
    content = ""
    if os.path.exists(gitignore_path):
        with open(gitignore_path, "r", encoding="utf-8") as gitignore:
            content = gitignore.read()
    existing = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in existing]
    if not missing:
        return
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        # Don't glue the first new pattern onto an unterminated last line.
        if content and not content.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def main():
    """Fine-tune the Stable Diffusion UNet with Flax/JAX on an image-caption dataset.

    Steps:
    - Parse the CLI arguments.
    - Configure logging and, optionally, a Hub repository.
    - Load and preprocess the dataset.
    - Load the tokenizer, text encoder, VAE and UNet.
    - Train data-parallel with ``jax.pmap``. Only the UNet parameters are
      optimized; the text encoder and VAE parameters are replicated but never updated.
    - On process 0, save a ``FlaxStableDiffusionPipeline`` to ``args.output_dir``
      and optionally push it to the Hub.

    Raises:
        ValueError: on unknown image/caption columns, an unknown scheduler
            prediction type, or a training set smaller than one global batch.
    """
    args = parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
            # Keep intermediate checkpoint folders out of the pushed repository.
            _ensure_gitignore_entries(args.output_dir, ["step_*", "epoch_*"])
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    # NOTE: the closures below reference `tokenizer`, which is assigned further down;
    # they only run once the dataloader is iterated, after it exists.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
        input_ids = inputs.input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    # Every process must build the training set. Previously only process 0 defined
    # `train_dataset`, so other hosts crashed with a NameError. The seeded shuffle
    # keeps the truncated subset identical across hosts.
    if args.max_train_samples is not None:
        num_samples = min(args.max_train_samples, len(dataset["train"]))
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(num_samples))
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = [example["input_ids"] for example in examples]

        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        )
        batch = {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
        }
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
    )
    # With drop_last=True a dataset smaller than one global batch yields no batches,
    # which would otherwise surface later as a ZeroDivisionError / unbound metric.
    if len(train_dataloader) == 0:
        raise ValueError(
            f"The training set has {len(train_dataset)} examples, fewer than one full batch of "
            f"{total_train_batch_size} (train_batch_size x local devices); nothing would be trained."
        )

    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
    )
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
    )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    if args.lr_scheduler != "constant":
        logger.warning(
            f"--lr_scheduler={args.lr_scheduler} is not implemented by this script; "
            "a constant learning rate schedule is used instead."
        )
    constant_scheduler = optax.constant_schedule(args.learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(args.max_grad_norm),
        adamw,
    )

    state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()

    # Initialize our training
    rng = jax.random.PRNGKey(args.seed)
    train_rngs = jax.random.split(rng, jax.local_device_count())

    def train_step(state, text_encoder_params, vae_params, batch, train_rng):
        """One pmapped optimization step; returns (new_state, metrics, next_rng)."""
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        def compute_loss(params):
            # Convert images to latent space
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            # Derive independent keys for the latent sample, the noise and the timesteps.
            # A JAX key must not be both consumed and split again.
            latent_rng, noise_rng, timestep_rng = jax.random.split(sample_rng, 3)
            latents = vae_outputs.latent_dist.sample(latent_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = jax.random.normal(noise_rng, latents.shape)
            # Sample a random timestep for each image
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(
                batch["input_ids"],
                params=text_encoder_params,
                train=False,
            )[0]

            # Predict the noise residual and compute loss.
            # The dropout key is supplied so train=True is valid even for non-zero dropout rates.
            model_pred = unet.apply(
                {"params": params},
                noisy_latents,
                timesteps,
                encoder_hidden_states,
                train=True,
                rngs={"dropout": dropout_rng},
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = (target - model_pred) ** 2
            loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad)

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics, new_train_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)
    text_encoder_params = jax_utils.replicate(text_encoder.params)
    vae_params = jax_utils.replicate(vae_params)

    # Train!
    # len(train_dataloader) already accounts for drop_last=True.
    num_update_steps_per_epoch = len(train_dataloader)

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_step_progress_bar = tqdm(
            total=num_update_steps_per_epoch, desc="Training...", position=1, leave=False
        )
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs)

            train_step_progress_bar.update(1)

            global_step += 1
            if global_step >= args.max_train_steps:
                break

        # The empty-dataloader check and the outer break below guarantee at least one
        # step ran in this epoch, so `train_metric` is bound (metrics of the last step).
        epoch_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {epoch_metric['loss']})")

        # Stop the epoch loop too, not just the step loop, once the budget is spent.
        if global_step >= args.max_train_steps:
            break

    # Create the pipeline using using the trained modules and save it.
    if jax.process_index() == 0:
        scheduler = FlaxPNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        )
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        pipeline.save_pretrained(
            args.output_dir,
            params={
                "text_encoder": get_params_to_save(text_encoder_params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(state.params),
                "safety_checker": safety_checker.params,
            },
        )

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    main()
589
590