GitHub Repository: huggingface/notebooks
Path: blob/main/examples/paligemma/Fine_tune_PaliGemma.ipynb
Kernel: Python 3

PaliGemma Fine-tuning

In this notebook, we will fine-tune the pretrained PaliGemma model on a small split of the VQAv2 dataset. Let's get started by installing the necessary libraries.

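```python
# peft and bitsandbytes are only needed for the optional LoRA/QLoRA cell below
!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate peft bitsandbytes
```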

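We will authenticate with notebook_login() to access the model, which is gated on the Hugging Face Hub. The same login is used later to push the fine-tuned model to the Hub.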

```python
from huggingface_hub import notebook_login

notebook_login()
```

Let's load the dataset.

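```python
from datasets import load_dataset

# load the first 10% of the VQAv2 training split
ds = load_dataset('HuggingFaceM4/VQAv2', split="train[:10%]")
```

```python
# drop the columns we don't need for training
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
ds = ds.remove_columns(cols_remove)
```

```python
split_ds = ds.train_test_split(test_size=0.05)
# we'll use the very small 5% "test" split as our training set for this demo
train_ds = split_ds["test"]
```

```python
train_ds
```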

Load the processor to preprocess the dataset.

```python
from transformers import PaliGemmaProcessor

model_id = "google/paligemma-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(model_id)
```

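We will now preprocess our examples. We build a prompt from a simple template ("answer " followed by the question) and pass the prompts, the batch of images, and the target answers to the processor. The answers go in through the suffix argument: the processor appends them to the prompt and returns labels in which every token that is not part of the answer (image tokens, prompt tokens, and padding) is set to -100, so the loss ignores them. This way, the model learns to generate only the response.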

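```python
import torch

device = "cuda"
# id of the <image> placeholder token; the processor inserts these for us, so it is not used below
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")

def collate_fn(examples):
    # prompt template: "answer " followed by the question
    texts = ["answer " + example["question"] for example in examples]
    # target answers, passed to the processor as the suffix
    labels = [example['multiple_choice_answer'] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    # tokenizes prompt + suffix, builds `labels` (non-answer tokens set to -100)
    # and preprocesses the images into `pixel_values`
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest",
                       tokenize_newline_separately=False)
    # cast the floating-point tensors (pixel_values) to bfloat16 and move everything to the GPU
    tokens = tokens.to(torch.bfloat16).to(device)
    return tokens
```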
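To make the label masking concrete, here is a simplified, pure-Python sketch of the rule applied when a suffix is given: positions whose token type id is 1 (the answer) keep their ids, and everything else becomes -100. The token ids, the shortened image prefix, and the assumption that padding positions carry token type id 0 are illustrative only, and the helper names `build_labels` and `num_image_tokens` are hypothetical. The image-token arithmetic assumes a 224 px input split into 14 px patches, as for the `paligemma-3b-pt-224` checkpoint.

```python
# Simplified sketch of prefix/suffix label masking (hypothetical helpers and token ids).

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def num_image_tokens(image_size: int = 224, patch_size: int = 14) -> int:
    """Number of <image> placeholder tokens for a square image (one per patch)."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


def build_labels(input_ids: list[int], token_type_ids: list[int]) -> list[int]:
    """Keep ids where token_type_id == 1 (the answer/suffix); mask everything else."""
    return [
        tok if ttype == 1 else IGNORE_INDEX
        for tok, ttype in zip(input_ids, token_type_ids)
    ]


# Hypothetical sequence: image tokens (shortened to 3), prompt, answer + eos, 2 pad tokens.
image_ids = [9000] * 3
prompt_ids = [10, 11, 12]   # stands in for "answer <question>"
answer_ids = [20, 21, 1]    # stands in for "<answer>" + eos
pad_ids = [0, 0]

input_ids = image_ids + prompt_ids + answer_ids + pad_ids
token_type_ids = (
    [0] * (len(image_ids) + len(prompt_ids))  # prefix: image + prompt
    + [1] * len(answer_ids)                   # suffix: the answer
    + [0] * len(pad_ids)                      # padding (assumed type 0)
)

labels = build_labels(input_ids, token_type_ids)
print(num_image_tokens())  # 256
print(labels)
```

Only the three answer positions survive in `labels`, so the loss is computed on the answer alone.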

Our dataset is quite general and similar to many of the datasets PaliGemma was trained on. In this case, we do not need to fine-tune the image encoder or the multimodal projector; we will fine-tune only the text decoder.

```python
from transformers import PaliGemmaForConditionalGeneration
import torch

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# freeze the image encoder
for param in model.vision_tower.parameters():
    param.requires_grad = False

# freeze the multimodal projector
for param in model.multi_modal_projector.parameters():
    param.requires_grad = False
```
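A quick way to confirm that the freezing worked is to count trainable parameters. The sketch below is a generic PyTorch helper; because loading PaliGemma is expensive, it is demonstrated on a hypothetical `ToyVLM` module that only mirrors the attribute names `vision_tower`, `multi_modal_projector`, and `language_model`. The helpers `freeze` and `count_parameters` are illustrative names, not part of transformers.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> None:
    """Disable gradients for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


class ToyVLM(nn.Module):
    """Hypothetical stand-in that only mirrors PaliGemma's attribute names."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(16, 8)          # 16*8 + 8 = 136 params
        self.multi_modal_projector = nn.Linear(8, 8)  # 8*8 + 8 = 72 params
        self.language_model = nn.Linear(8, 32)        # 8*32 + 32 = 288 params


toy = ToyVLM()
freeze(toy.vision_tower)
freeze(toy.multi_modal_projector)
trainable, total = count_parameters(toy)
print(f"trainable: {trainable} / {total}")  # trainable: 288 / 496
```

On the notebook's model, `count_parameters(model)` should report that only the text decoder's parameters remain trainable.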

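Alternatively, if you want to do LoRA or QLoRA fine-tuning, you can run the cell below instead. As written, it loads the base model quantized to 4-bit (QLoRA) and adds trainable LoRA adapters on top; for LoRA in full precision, load the model without the quantization_config argument.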

```python
import torch
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from peft import get_peft_model, LoraConfig

# 4-bit NF4 quantization of the base model (QLoRA); computation runs in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# rank-8 LoRA adapters on the attention and MLP projection layers
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
```
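The trainable-parameter count printed above can be checked by hand. LoRA freezes each targeted weight W (shape out × in) and learns a low-rank update BA, with A of shape r × in and B of shape out × r, i.e. r × (in + out) extra parameters per module. The layer sizes in this sketch are assumptions taken from the public Gemma-2B and SigLIP So400m configurations, not from this notebook. With them, the total lands exactly on 11,298,816, which suggests that `target_modules` also matched the `q_proj`, `k_proj`, and `v_proj` layers of the vision encoder, not only those of the text decoder.

```python
# Reproduce the LoRA trainable-parameter count (layer sizes are assumed from public configs).


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in) and B (out x r), i.e. r * (in + out) parameters, no bias."""
    return r * (in_features + out_features)


R = 8  # LoraConfig(r=8)

# Text decoder (assumed Gemma-2B): 18 layers, hidden 2048, MLP 16384,
# 8 query heads and 1 key/value head, each of size 256.
hidden, mlp, head_dim, n_heads, n_kv_heads, n_text_layers = 2048, 16384, 256, 8, 1, 18
text_per_layer = (
    lora_params(hidden, n_heads * head_dim, R)       # q_proj
    + lora_params(hidden, n_kv_heads * head_dim, R)  # k_proj
    + lora_params(hidden, n_kv_heads * head_dim, R)  # v_proj
    + lora_params(n_heads * head_dim, hidden, R)     # o_proj
    + lora_params(hidden, mlp, R)                    # gate_proj
    + lora_params(hidden, mlp, R)                    # up_proj
    + lora_params(mlp, hidden, R)                    # down_proj
)

# Vision encoder (assumed SigLIP So400m): 27 layers, hidden 1152. Its attention layers are
# also named q_proj/k_proj/v_proj, so they match; out_proj, fc1 and fc2 do not.
vis_hidden, n_vis_layers = 1152, 27
vision_per_layer = 3 * lora_params(vis_hidden, vis_hidden, R)

total_lora = n_text_layers * text_per_layer + n_vis_layers * vision_per_layer
print(total_lora)  # 11298816
```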

We will now initialize the TrainingArguments.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    num_train_epochs=2,
    remove_unused_columns=False,  # keep the raw columns; collate_fn needs them
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",
    save_strategy="steps",
    save_steps=1000,
    push_to_hub=True,
    save_total_limit=1,
    output_dir="paligemma_vqav2",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,  # collate_fn already moves tensors to the GPU
)
```
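With these arguments, each optimizer update sees per_device_train_batch_size × gradient_accumulation_steps = 16 examples per device, and the learning rate warms up linearly over 2 steps and then decays linearly to 0 (the Trainer's default `lr_scheduler_type` is `"linear"`). The sketch below approximates both. `NUM_EXAMPLES` is a hypothetical dataset size (substitute `len(train_ds)`), the helper names are illustrative, and the step count may differ slightly from the Trainer's because its rounding varies between versions.

```python
import math


def optimizer_steps(num_examples: int, per_device_batch_size: int,
                    grad_accum_steps: int, num_epochs: int, num_devices: int = 1) -> int:
    """Approximate number of optimizer updates over the whole run."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(math.ceil(batches_per_epoch / grad_accum_steps), 1)
    return updates_per_epoch * num_epochs


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


NUM_EXAMPLES = 1600  # hypothetical; use len(train_ds) in the notebook
effective_batch_size = 4 * 4  # per_device_train_batch_size * gradient_accumulation_steps
total_steps = optimizer_steps(NUM_EXAMPLES, 4, 4, num_epochs=2)

print(effective_batch_size, total_steps)  # 16 200
for step in (0, 1, 2, 101, 200):
    print(step, linear_warmup_lr(step, 2e-5, warmup_steps=2, total_steps=total_steps))
```

With only 2 warmup steps, the learning rate reaches its 2e-5 peak almost immediately and then decays over the rest of the run.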

We can now start training.

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    data_collator=collate_fn,
    args=args,
)
```

```python
trainer.train()
```

```python
trainer.push_to_hub()
```

You can find the steps for running inference here.