Path: blob/main/examples/idefics/finetune_image_captioning.py
# adapted from https://github.com/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/image_captioning.ipynb

# This example demonstrates normal finetuning (w/o peft) - for the sake of keeping the memory
# requirements small it freezes the original pre-trained text and image layers to keep the memory
# requirements to just 40GB. If you have multiple GPUs then you can remove the freeze part to
# finetune the whole model. Alternatively use the PEFT solution as shown in the
# IDEFICS_finetuning_demo.ipynb notebook, which requires only 20GB to finetune the whole model.

import torch
import torchvision.transforms as transforms

from datasets import load_dataset
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments

device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = "HuggingFaceM4/idefics-9b"
# checkpoint = "HuggingFaceM4/tiny-random-idefics"

processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)

# freeze the original text and vision models and finetune only the layers added by IDEFICS;
# you can unfreeze the whole model, but it'll require multiple GPUs to finetune
model.model.freeze_text_layers()
model.model.freeze_vision_layers()

# helper util to sanity-check generation before and after finetuning
def check_inference():
    url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
    prompts = [
        url,
        "Question: What's on the picture? Answer:",
    ]

    inputs = processor(prompts, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_length=150)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)

# check generation before finetuning
check_inference()
# well, actually it looks like the model is already aware of pokemon - but this dataset will refine it further

# finetune the model on the pokemon types dataset
ds = load_dataset("GabeHD/pokemon-type-captions")
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
eval_ds = ds["test"]

def convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite
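# quick illustrative sanity check for convert_to_rgb (an addition for this writeup,
# not part of the original training flow): a translucent red RGBA image should come
# back in RGB mode with the white background showing through, instead of the dark
# halo a naive convert("RGB") would produce
_probe = Image.new("RGBA", (4, 4), (255, 0, 0, 128))
_rgb = convert_to_rgb(_probe)
assert _rgb.mode == "RGB"
_r, _g, _b = _rgb.getpixel((0, 0))
assert _g > 0 and _b > 0  # white shows through the translucent red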
def ds_transforms(example_batch):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    prompts = []
    # iterate over the examples in the batch; note that len(example_batch) would
    # count the columns of the batch dict, not the examples
    for i in range(len(example_batch["image"])):
        prompts.append(
            [
                example_batch["image"][i],
                f"Question: What's on the picture? Answer: {example_batch['text'][i]}\n",
            ],
        )

    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)

    # standard causal-LM setup: the labels are the input ids themselves
    inputs["labels"] = inputs["input_ids"]

    return inputs

train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)

model_name = checkpoint.split("/")[1]

# this setup requires about 40GB of GPU memory
training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-6,
    num_train_epochs=10,
    bf16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    dataloader_pin_memory=False,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=1000,  # don't save until ready...
    eval_steps=40,
    logging_steps=40,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to=None,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

# check generation again after finetuning
check_inference()

# after finetuning, ideally we want generate to produce something like:
# "a drawing of a pink and blue pokemon"
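# The header comment mentions a PEFT alternative (see IDEFICS_finetuning_demo.ipynb).
# Below is a minimal LoRA sketch along those lines; the hyperparameters and the
# target_modules names are assumptions for illustration - consult the notebook for
# the exact configuration. The function is defined but never called in this script.
def build_lora_model():
    from peft import LoraConfig, get_peft_model  # requires the `peft` package

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj"],  # assumed attention projection names
        lora_dropout=0.05,
        bias="none",
    )
    lora_model = get_peft_model(
        IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16),
        lora_config,
    )
    lora_model.print_trainable_parameters()  # only the LoRA adapters are trainable
    return lora_model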