Fine-tune Pix2Struct using Hugging Face `transformers` and `datasets` 🤗
This tutorial is largely based on the GiT tutorial on how to fine-tune GiT on a custom image captioning dataset. Here we will use a dummy dataset of football players ⚽ that is uploaded on the Hub. The images have been manually selected together with the captions. Check the 🤗 documentation on how to create and upload your own image-text dataset.
Model overview
In this tutorial, we will load an architecture called Pix2Struct, recently released by Google and made available on the 🤗 Hub! This architecture differs from other models in its pretraining procedure and in the way it extracts patches from the image, using an aspect-ratio-preserving patch extraction method.
The release came with no fewer than 20 checkpoints!
As each checkpoint has been fine-tuned on a specific domain, let's fine-tune our own Pix2Struct on our target domain: football players! For that we will use the `google/pix2struct-base` checkpoint, which corresponds to a general-purpose model that you can load and fine-tune.
Set-up environment
Run the cells below to set up the environment.
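A typical setup cell could look like the following; it installs `transformers` from source (Pix2Struct support was very recent at the time) together with `datasets`:

```python
# Install transformers from source to get recent Pix2Struct support,
# plus datasets for loading the image-text pairs.
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q datasets
```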
Load the image captioning dataset
Let's load the image captioning dataset; you just need a few lines of code for that. The dataset consists of only 6 images that we have manually labeled for the sake of the tutorial.
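A minimal loading sketch, assuming the dummy dataset lives at `ybelkada/football-dataset` on the Hub (swap in your own repo id if needed):

```python
from datasets import load_dataset

# Load the small image-text dataset from the Hub (6 manually captioned images).
dataset = load_dataset("ybelkada/football-dataset", split="train")
```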
Let's retrieve the caption of the first example:
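Assuming the caption is stored in a `text` column:

```python
# Print the caption of the first example.
print(dataset[0]["text"])
```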
And the corresponding image
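Assuming the image is stored in an `image` column:

```python
# The image is returned as a PIL.Image object, so it renders directly in the notebook.
dataset[0]["image"]
```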
Create PyTorch Dataset
Understanding the `max_patches` argument
The paper introduces a new paradigm for processing the input image. The model extracts `n_patches` aspect-ratio-preserving patches from the image, then pads the sequence with padding tokens until it reaches `max_patches` patches. This argument is quite crucial for training and evaluation, as the model is very sensitive to it.
For the sake of our example, we will fine-tune a model with `max_patches=1024`.
Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and the `-large` models with `4096`.
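Below is a sketch of a PyTorch dataset that wraps the Hub dataset and calls the processor (loaded in the next section) with `max_patches=1024`; the `image` and `text` column names follow the dataset assumed above:

```python
from torch.utils.data import Dataset

MAX_PATCHES = 1024

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # The processor turns the image into `flattened_patches` + `attention_mask`,
        # padded/truncated to MAX_PATCHES patches.
        encoding = self.processor(images=item["image"], return_tensors="pt", max_patches=MAX_PATCHES)
        # Remove the batch dimension added by return_tensors="pt".
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding
```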
Load model and processor
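A minimal loading cell for the general-purpose base checkpoint could look like this:

```python
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the general-purpose base checkpoint and its processor.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base").to(device)
```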
Now that we have loaded the processor, let's instantiate the PyTorch dataset and the dataloader:
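A sketch, reusing the `ImageCaptioningDataset` class from above together with a simple collator that tokenizes the captions into labels (the batch size and `max_length` below are arbitrary choices for this tiny dataset):

```python
import torch
from torch.utils.data import DataLoader

def collator(batch):
    # Tokenize the captions; they become the labels the decoder learns to generate.
    texts = [item["text"] for item in batch]
    text_inputs = processor(
        text=texts, padding="max_length", truncation=True, max_length=20, return_tensors="pt"
    )

    return {
        "labels": text_inputs.input_ids,
        "flattened_patches": torch.stack([item["flattened_patches"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
    }

train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)
```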
Train the model
Let's train the model! Simply run the cell below to train it. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trial and error, as the model can easily fall into "model collapse" (always predicting the same output, no matter the input) if the hyper-parameters are not chosen correctly. In this example, we found out that using the `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.
Let's also print the generation output of the model every 20 epochs!
Bear in mind that the model took some time to converge; for instance, to get decent results we had to let the script run for ~1 hour.
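A training loop along these lines works under the setup above (the epoch count is a rough assumption; adjust it to your compute budget):

```python
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
for epoch in range(200):
    print("Epoch:", epoch)
    for batch in train_dataloader:
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Every 20 epochs, generate captions on the last batch to monitor progress.
    if (epoch + 1) % 20 == 0:
        model.eval()
        predictions = model.generate(
            flattened_patches=flattened_patches, attention_mask=attention_mask
        )
        print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))
        model.train()
```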
Inference
Let's check the results on our train dataset
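A minimal sketch, assuming the `dataset`, `processor`, and fine-tuned `model` from the previous sections are still in memory:

```python
model.eval()

example = dataset[0]
inputs = processor(images=example["image"], return_tensors="pt", max_patches=MAX_PATCHES).to(device)

# Generate a caption for the first training image and decode it back to text.
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```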
Load from the Hub
Once trained, you can push the model and processor to the Hub to use them later. Meanwhile, you can play with the model that we have fine-tuned!
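A sketch of the push/reload round trip; the repo id below is a placeholder, replace it with your own namespace:

```python
# Push the fine-tuned weights and the processor to the Hub (hypothetical repo id).
model.push_to_hub("your-username/pix2struct-base-football")
processor.push_to_hub("your-username/pix2struct-base-football")

# Later, reload them directly from the Hub.
model = Pix2StructForConditionalGeneration.from_pretrained("your-username/pix2struct-base-football").to(device)
processor = Pix2StructProcessor.from_pretrained("your-username/pix2struct-base-football")
```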
Let's check the results on our train dataset once more, this time with the model loaded from the Hub!