GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/imagic/train_imagic.py
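
# Imagic training script: given a single input image and a target text prompt, first
# optimize the text embedding so it reconstructs the image (stage 1), then fine-tune
# the diffusion model around that embedding (stage 2). See "Imagic: Text-Based Real
# Image Editing with Diffusion Models" (https://arxiv.org/abs/2210.09276).
#
# Illustrative launch command (paths, model id, and prompt are placeholders, not taken
# from this repo):
#   accelerate launch train_imagic.py \
#     --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
#     --input_image="dog.jpg" \
#     --target_text="A sitting dog" \
#     --output_dir="imagic-model" \
#     --seed=42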
import argparse
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
import numpy as np
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--input_image",
        type=str,
        default=None,
        required=True,
        help="Path to the input image to edit.",
    )
    parser.add_argument(
        "--target_text",
        type=str,
        default=None,
        help="The target text describing the output image.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--emb_train_steps",
        type=int,
        default=500,
        help="Total number of embedding optimization steps to perform.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=1000,
        help="Total number of model fine-tuning steps to perform.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--emb_learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for optimizing the embeddings.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-6,
        help="Learning rate for fine-tuning the model.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main():
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=True)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=True)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=True)
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage, or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.Adam8bit
    else:
        optimizer_class = torch.optim.Adam

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
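    # These beta values and the "scaled_linear" schedule match the noise schedule that
    # Stable Diffusion v1 was trained with.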

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
    # as these models are only used for inference; keeping weights in full precision is not required.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # Encode the input image.
    input_image = Image.open(args.input_image).convert("RGB")

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
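    # Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], the input range
    # the Stable Diffusion VAE expects.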

    init_image = image_transforms(input_image)
    init_image = init_image[None].to(device=accelerator.device, dtype=weight_dtype)
    with torch.inference_mode():
        init_latents = vae.encode(init_image).latent_dist.sample()
        init_latents = 0.18215 * init_latents
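    # 0.18215 is the standard latent scaling factor for the Stable Diffusion VAE,
    # bringing the latents to roughly unit variance.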

    # Encode the target text.
    text_ids = tokenizer(
        args.target_text,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    text_ids = text_ids.to(device=accelerator.device)
    with torch.inference_mode():
        target_embeddings = text_encoder(text_ids)[0]
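    # text_encoder(...)[0] is the per-token hidden-state sequence, which is what the
    # UNet consumes as cross-attention conditioning.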

    del vae, text_encoder
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    target_embeddings = target_embeddings.float()
    optimized_embeddings = target_embeddings.clone()
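    # Stage 1 of Imagic: only the cloned embedding goes into the optimizer below, so
    # the text embedding is optimized to reconstruct the input image while the UNet
    # weights receive no optimizer updates in this stage.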

    # Optimize the text embeddings first.
    optimized_embeddings.requires_grad_(True)
    optimizer = optimizer_class(
        [optimized_embeddings],  # only optimize the embeddings
        lr=args.emb_learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        # weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    unet, optimizer = accelerator.prepare(unet, optimizer)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("imagic", config=vars(args))

    def train_loop(pbar, optimizer, params):
        loss_avg = AverageMeter()
        for step in pbar:
            with accelerator.accumulate(unet):
                noise = torch.randn_like(init_latents)
                bsz = init_latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(init_latents, noise, timesteps)

                noise_pred = unet(noisy_latents, timesteps, optimized_embeddings).sample

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                # if accelerator.sync_gradients:  # results aren't good with it; may need more training with it.
                #     accelerator.clip_grad_norm_(params, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                loss_avg.update(loss.detach_(), bsz)

            if not step % args.log_interval:
                logs = {"loss": loss_avg.avg.item()}
                pbar.set_postfix(**logs)
                accelerator.log(logs, step=step)

        accelerator.wait_for_everyone()
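    # Both training stages reuse this loop: the loss is the standard diffusion
    # noise-prediction objective ||noise - unet(noisy_latents, t, emb)||^2, and which
    # parameters actually move is determined by the optimizer that is passed in.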

    progress_bar = tqdm(range(args.emb_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Optimizing embedding")

    train_loop(progress_bar, optimizer, optimized_embeddings)

    optimized_embeddings.requires_grad_(False)
    if accelerator.is_main_process:
        torch.save(target_embeddings.cpu(), os.path.join(args.output_dir, "target_embeddings.pt"))
        torch.save(optimized_embeddings.cpu(), os.path.join(args.output_dir, "optimized_embeddings.pt"))
        with open(os.path.join(args.output_dir, "target_text.txt"), "w") as f:
            f.write(args.target_text)
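    # Stage 2 of Imagic: the optimized embedding is now frozen, and the UNet is
    # fine-tuned so it reconstructs the input image even better at that embedding.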

    # Fine tune the diffusion model.
    optimizer = optimizer_class(
        accelerator.unwrap_model(unet).parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        # weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    optimizer = accelerator.prepare(optimizer)

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Fine Tuning")
    unet.train()

    train_loop(progress_bar, optimizer, unet.parameters())

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            use_auth_token=True,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
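
# Inference sketch (an assumption based on the Imagic paper, not part of this script):
# load the saved pipeline plus the two saved embeddings, then linearly interpolate
#     emb = alpha * target_embeddings + (1 - alpha) * optimized_embeddings
# where alpha controls edit strength (alpha = 0 reproduces the input image, larger
# alpha applies the target text more strongly), and feed `emb` to the pipeline as
# precomputed prompt embeddings (e.g. the `prompt_embeds` argument available in
# recent diffusers releases).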