GitHub Repository: huggingface/notebooks
Path: blob/main/examples/idefics/finetune_image_captioning.py

# adapted from https://github.com/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/image_captioning.ipynb

# This example demonstrates normal finetuning (w/o PEFT). To keep the memory requirements
# small (about 40GB) it freezes the original pre-trained text and vision layers and finetunes
# only the layers added by IDEFICS. If you have multiple GPUs, you can remove the freezing
# calls below to finetune the whole model. Alternatively, use the PEFT solution shown in the
# IDEFICS_finetuning_demo.ipynb notebook, which requires only 20GB to finetune the whole model.
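
# As a rough, hedged sketch of that PEFT alternative (the hyperparameters and target module
# names below are illustrative assumptions, not taken from this script -- see the demo
# notebook for the authoritative version):
#
#   from peft import LoraConfig, get_peft_model
#
#   lora_config = LoraConfig(
#       r=16,
#       lora_alpha=32,
#       lora_dropout=0.05,
#       target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed module names
#   )
#   model = get_peft_model(model, lora_config)
#   model.print_trainable_parameters()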

import torch
import torchvision.transforms as transforms

from datasets import load_dataset
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments

device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = "HuggingFaceM4/idefics-9b"
# checkpoint = "HuggingFaceM4/tiny-random-idefics"

processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)
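
# the IDEFICS processor bundles a tokenizer and an image processor; a "prompt" is a list
# that can interleave images (or image URLs) with text, as in check_inference() below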

# freeze the original text and vision models and finetune only the layers added by IDEFICS
# you can unfreeze the whole model, but it'll require multiple GPUs to finetune
model.model.freeze_text_layers()
model.model.freeze_vision_layers()
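
# (optional sanity check, a minimal sketch not in the original script) confirm the freeze
# worked by counting trainable vs. total parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")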

# helper utility
def check_inference():
    url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
    prompts = [
        url,
        "Question: What's on the picture? Answer:",
    ]

    inputs = processor(prompts, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_length=150)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)

# check generation before finetuning
check_inference()
# well, actually it looks like the model is already aware of pokemon - but this dataset will refine it further

# finetune the model on the pokemon types dataset
ds = load_dataset("GabeHD/pokemon-type-captions")
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
eval_ds = ds["test"]
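
# (optional, not in the original script) peek at one raw example; the transform below
# assumes the dataset has "image" and "text" columns
print(train_ds[0]["text"])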

def convert_to_rgb(image):
    # `image.convert("RGB")` alone would only work for .jpg images, as it creates a wrong
    # background for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite

def ds_transforms(example_batch):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    prompts = []
    for i in range(len(example_batch["text"])):  # iterate over the batch size, not the number of dict keys
        prompts.append(
            [
                example_batch["image"][i],
                f"Question: What's on the picture? Answer: {example_batch['text'][i]}\n",
            ],
        )

    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)

    # standard causal-LM objective; the model shifts the labels internally
    inputs["labels"] = inputs["input_ids"]

    return inputs

train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)
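
# (optional, a minimal sketch not in the original script) sanity-check the on-the-fly
# transform by pulling a small batch and inspecting the tensor shapes
batch = train_ds[:2]
print({k: v.shape for k, v in batch.items()})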

model_name = checkpoint.split("/")[1]

# this setup requires about 40GB of GPU memory
training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-6,
    num_train_epochs=10,
    bf16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    dataloader_pin_memory=False,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=1000,  # don't save until ready...
    eval_steps=40,
    logging_steps=40,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to="none",  # "none" disables logging integrations; passing None would enable all installed ones
)
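
# note: with per_device_train_batch_size=32 and gradient_accumulation_steps=2, the
# effective batch size is 32 * 2 = 64 per GPU per optimizer step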

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

# check generation again after finetuning
check_inference()

# after finetuning, ideally we want generate to produce something like: a drawing of a pink and blue pokemon