"""1On one node, launch with `deepspeed --num_gpus N idefics_zero3_finetuning.py`2by replacing N with the number of your GPUs34For several nodes, using Slurm, a template script is provided at5`examples/idefics/idefics_zero3_finetuning/slurm_script_idefics_zero3_finetuning_multinode.slurm`67For more information, follow the tutorial on using DeepSpeed with Transformers at8https://huggingface.co/docs/transformers/main_classes/deepspeed9"""1011import torch12import torchvision.transforms as transforms13from datasets import load_dataset14from PIL import Image15from transformers import AutoProcessor, IdeficsForVisionText2Text, Trainer, TrainingArguments161718device = "cuda" if torch.cuda.is_available() else "cpu"1920checkpoint = "HuggingFaceM4/idefics-9b"2122processor = AutoProcessor.from_pretrained(checkpoint, use_auth_token=True)232425def convert_to_rgb(image):26# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background27# for transparent images. The call to `alpha_composite` handles this case28if image.mode == "RGB":29return image30image_rgba = image.convert("RGBA")31background = Image.new("RGBA", image_rgba.size, (255, 255, 255))32alpha_composite = Image.alpha_composite(background, image_rgba)33alpha_composite = alpha_composite.convert("RGB")34return alpha_composite353637def ds_transforms(example_batch):38image_size = processor.image_processor.image_size39image_mean = processor.image_processor.image_mean40image_std = processor.image_processor.image_std41image_transform = transforms.Compose(42[43convert_to_rgb,44transforms.RandomResizedCrop(45(image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC46),47transforms.ToTensor(),48transforms.Normalize(mean=image_mean, std=image_std),49]50)51prompts = []52for i in range(len(example_batch["caption"])):53# We split the captions to avoid having very long examples, which would require more GPU ram during training54caption = example_batch["caption"][i].split(".")[0]55try:56# There are a handful of images that are not hosted anymore. This is a small (dummy) hack to skip these57processor.image_processor.fetch_images(example_batch["image_url"][i])58except Exception:59print(60"Warning: at least one image couldn't be retrieved from the internet in an example. Skipping the"61" batch."62)63prompts.append(64[65example_batch["image_url"][i],66f"Question: What's on the picture? Answer: This is {example_batch['name'][i]}. 
{caption}</s>",67],68)69inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)70inputs["labels"] = inputs["input_ids"]71return inputs727374# load and prepare dataset75ds = load_dataset("TheFusion21/PokemonCards")76ds = ds["train"].train_test_split(test_size=0.002)77train_ds = ds["train"]78eval_ds = ds["test"]79train_ds.set_transform(ds_transforms)80eval_ds.set_transform(ds_transforms)818283# Important, define the training_args before the model84ds_config = {85"communication_data_type": "fp32",86"bf16": {"enabled": True},87"zero_optimization": {88"stage": 3,89"overlap_comm": False,90"reduce_bucket_size": "auto",91"contiguous_gradients": True,92"stage3_gather_16bit_weights_on_model_save": False,93"stage3_prefetch_bucket_size": "auto",94"stage3_param_persistence_threshold": "auto",95"stage3_max_live_parameters": 2e9,96"stage3_max_reuse_distance": 2e9,97"offload_optimizer": {"device": "none"},98"offload_param": {"device": "none"},99},100"gradient_clipping": "auto",101"train_batch_size": "auto",102"train_micro_batch_size_per_gpu": "auto",103"steps_per_print": 2000000,104}105training_args = TrainingArguments(106output_dir="idefics-pokemon",107learning_rate=2e-4,108bf16=True,109per_device_train_batch_size=1,110per_device_eval_batch_size=1,111gradient_accumulation_steps=1,112# gradient_checkpointing=True, # Uncomment if OOM113dataloader_pin_memory=False,114save_total_limit=3,115evaluation_strategy="steps",116save_strategy="steps",117save_steps=40,118eval_steps=20,119logging_steps=20,120max_steps=40,121remove_unused_columns=False,122push_to_hub=False,123label_names=["labels"],124load_best_model_at_end=True,125report_to="none",126optim="adamw_torch",127deepspeed=ds_config,128)129130model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)131132trainer = Trainer(133model=model,134args=training_args,135train_dataset=train_ds,136eval_dataset=eval_ds,137)138139result = trainer.train()140print(result) # Prints one per process - mostly here for sanity check141142143