GitHub Repository: huggingface/notebooks
Path: blob/main/examples/image_captioning_blip.ipynb
Kernel: Python 3 (ipykernel)

Fine-tune BLIP using Hugging Face transformers and datasets 🤗

This tutorial is largely based on the GiT tutorial on how to fine-tune GiT on a custom image captioning dataset. Here we will use a dummy dataset of football players ⚽ that is uploaded to the Hub. The images have been manually selected together with the captions. Check the 🤗 documentation on how to create and upload your own image-text dataset.
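If you want to try this with your own data, the sketch below shows one minimal way to build and push an image-text dataset with 🤗 datasets; the file names, captions, and repository name are placeholders.

from datasets import Dataset, Image

# Hypothetical local files and captions -- replace with your own data
data = {
    "image": ["images/player_1.jpg", "images/player_2.jpg"],
    "text": ["Player 1 celebrating a goal", "Player 2 taking a free kick"],
}

# Cast the path column to the Image feature so files are decoded into PIL images
ds = Dataset.from_dict(data).cast_column("image", Image())
# ds.push_to_hub("your-username/my-image-captioning-dataset")  # requires huggingface-cli login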

Set up the environment

!pip install git+https://github.com/huggingface/transformers.git@main
!pip install -q datasets

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("image_captioning_blip_notebook", framework="pytorch")

Load the image captioning dataset

Let's load the image captioning dataset; it only takes a few lines of code.

from datasets import load_dataset

dataset = load_dataset("ybelkada/football-dataset", split="train")
WARNING:datasets.builder:Using custom data configuration ybelkada--football-dataset-1ad065f8e9005a29
WARNING:datasets.builder:Found cached dataset parquet (/root/.cache/huggingface/datasets/ybelkada___parquet/ybelkada--football-dataset-1ad065f8e9005a29/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)

Let's retrieve the caption of the first example:

dataset[0]["text"]
"Benzema after Real Mardid's win against PSG"

And the corresponding image:

dataset[0]["image"]
Image in a Jupyter notebook

Create PyTorch Dataset

The lines below are entirely copied from the original notebook!

from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding

Load model and processor

from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

Now that we have loaded the processor, let's create the PyTorch dataset and the dataloader:

train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
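As an optional sanity check (not part of the original notebook), you can pull one batch from the dataloader and inspect the tensors produced by the processor:

# Expect input_ids, attention_mask and pixel_values, each with a leading batch dimension of 2
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)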

Train the model

Let's train the model! Simply run the cell below to train it.

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Epoch: 0
Loss: 13.106168746948242
Loss: 10.644421577453613
Loss: 9.593768119812012
Epoch: 1
Loss: 9.306917190551758
Loss: 9.081585884094238
Loss: 8.899713516235352
...
Epoch: 48
Loss: 1.3689672946929932
Loss: 1.3692435026168823
Loss: 1.3685294389724731
Epoch: 49
Loss: 1.3687036037445068
Loss: 1.3680329322814941
Loss: 1.3681769371032715

Inference

Let's check the results on our training dataset:

# load image
example = dataset[0]
image = example["image"]
image
Image in a Jupyter notebook
# prepare image for the model
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
benzema after real mardid's win against psg
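You can also caption an image that was not part of the training set. The snippet below is a sketch with a placeholder URL; any RGB image that PIL can open will work the same way.

import requests
from PIL import Image

url = "https://example.com/some_football_image.jpg"  # placeholder URL -- replace with a real image
raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

inputs = processor(images=raw_image, return_tensors="pt").to(device)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])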

Load from the Hub

Once trained, you can push the model and processor to the Hub to reuse them later, as sketched below. In the meantime, you can play with the model that we have fine-tuned!
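Pushing your own fine-tuned checkpoint takes one call per object. A minimal sketch, assuming you are logged in with huggingface-cli login and replace the hypothetical repository name below with your own:

# Hypothetical repository name -- replace "your-username" with your Hub username
model.push_to_hub("your-username/blip-image-captioning-base-football-finetuned")
processor.push_to_hub("your-username/blip-image-captioning-base-football-finetuned")

Below we load the checkpoint that was already fine-tuned and pushed for this tutorial: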

from transformers import BlipForConditionalGeneration, AutoProcessor

model = BlipForConditionalGeneration.from_pretrained("ybelkada/blip-image-captioning-base-football-finetuned").to(device)
processor = AutoProcessor.from_pretrained("ybelkada/blip-image-captioning-base-football-finetuned")

Let's check the results on our training dataset!

from matplotlib import pyplot as plt

fig = plt.figure(figsize=(18, 14))

# prepare each image for the model and plot it with its generated caption
for i, example in enumerate(dataset):
    image = example["image"]

    inputs = processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    fig.add_subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Generated caption: {generated_caption}")
Image in a Jupyter notebook