GitHub Repository: huggingface/notebooks
Path: blob/main/examples/image_classification-tf.ipynb
Kernel: Python 3 (ipykernel)

Fine-tuning for Image Classification with 🤗 Transformers

This notebook shows how to fine-tune any pretrained Vision model for Image Classification on a custom dataset. The idea is to add a randomly initialized classification head on top of a pre-trained encoder, and fine-tune the whole model on a labeled dataset.

ImageFolder

This notebook leverages the ImageFolder feature to easily run the notebook on a custom dataset (namely, EuroSAT in this tutorial). You can either load a Dataset from local folders or from local/remote files, like zip or tar.

Any model

This notebook is built to run on any image classification dataset with any vision model checkpoint from the Model Hub as long as that model has a TensorFlow version with an Image Classification head, such as:

Data augmentation

This notebook leverages TensorFlow's image module (tf.image) for applying data augmentation. Alternative notebooks that leverage other libraries, such as Albumentations, are to come!


Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the two parameters in the cell below (the model checkpoint and the batch size), then the rest of the notebook should run smoothly.

In this notebook, we'll fine-tune from the https://huggingface.co/microsoft/swin-tiny-patch4-window7-224 checkpoint, but note that there are many, many more available on the hub.

model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # pre-trained model from which to fine-tune
batch_size = 32  # batch size for training and evaluation

Before we start, let's install the datasets and transformers libraries.

!pip install -q datasets transformers

If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.

First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!) then execute the following cell and input your token:

from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS to upload your model checkpoints:

%%capture
!git lfs install
!git config --global credential.helper store

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("image_classification_notebook", framework="tensorflow")

Fine-tuning a model on an image classification task

In this notebook, we will see how to fine-tune one of the 🤗 Transformers vision models on an Image Classification dataset.

Given an image, the goal is to predict an appropriate class for it, like "tiger". The screenshot below is taken from a ViT fine-tuned on ImageNet-1k - try out the inference widget!


Loading the dataset

We will use the 🤗 Datasets library's ImageFolder feature to download our custom dataset into a DatasetDict.

In this case, the EuroSAT dataset is hosted remotely, so we provide the data_files argument. Alternatively, if you have local folders with images, you can load them using the data_dir argument.

from datasets import load_dataset

# load a custom dataset from local/remote files or folders using the ImageFolder feature

# option 1: local/remote files (supporting the following formats: tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://madm.dfki.de/files/sentinel/EuroSAT.zip")

# note that you can also provide several splits:
# dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})

# note that you can push your dataset to the hub very easily (and reload afterwards using load_dataset)!
# dataset.push_to_hub("nielsr/eurosat")
# dataset.push_to_hub("nielsr/eurosat", private=True)
# dataset = load_dataset("nielsr/eurosat")

# option 2: local folder
# dataset = load_dataset("imagefolder", data_dir="path_to_folder")

# option 3: just load any existing dataset from the hub, like CIFAR-10, FashionMNIST ...
# dataset = load_dataset("cifar10")
Using custom data configuration default-71c8f80171b09b95 Reusing dataset imagefolder (/home/amy/.cache/huggingface/datasets/imagefolder/default-71c8f80171b09b95/0.0.0/d83791becf4369e27b95a98928b7c95213f7901edd8134aaac48cb591b8b21f0)
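ImageFolder derives each image's label from the name of the folder that contains it, and the class names come out in sorted order (the EuroSAT label names listed further down are alphabetical). The stdlib sketch below illustrates only that convention. `infer_labels` and the example paths are hypothetical, and this is not the 🤗 Datasets implementation.

```python
from pathlib import PurePosixPath


def infer_labels(paths):
    """Map each image path to an integer label taken from its parent folder name.

    Simplified illustration of the folder-name convention; not library code.
    """
    # The class name is the name of the directory that directly contains the file
    class_names = sorted({PurePosixPath(p).parent.name for p in paths})
    name_to_id = {name: i for i, name in enumerate(class_names)}
    label_ids = [name_to_id[PurePosixPath(p).parent.name] for p in paths]
    return class_names, label_ids


# Hypothetical paths, loosely modelled on a folder-per-class layout
paths = [
    "EuroSAT/Forest/Forest_1.jpg",
    "EuroSAT/AnnualCrop/AnnualCrop_7.jpg",
    "EuroSAT/Forest/Forest_2.jpg",
]
class_names, label_ids = infer_labels(paths)
print(class_names)  # ['AnnualCrop', 'Forest']
print(label_ids)    # [1, 0, 1]
```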

Let us also load the Accuracy metric, which we'll use to evaluate our model both during and after training.

from datasets import load_metric

metric = load_metric("accuracy")

The dataset object itself is a DatasetDict, which contains one key per split (in this case, only "train" for a training split).

dataset
DatasetDict({ train: Dataset({ features: ['image', 'label'], num_rows: 27000 }) })

To access an actual element, you need to select a split first, then give an index:

example = dataset["train"][10]
example
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64 at 0x7F343552BF10>, 'label': 5}

Each example consists of an image and a corresponding label. We can also verify this by checking the features of the dataset:

dataset["train"].features
{'image': Image(decode=True, id=None), 'label': ClassLabel(num_classes=10, names=['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake'], id=None)}

The cool thing is that we can directly view the image (as the 'image' field is an Image feature), as follows:

example['image']
Image in a Jupyter notebook

Let's make it a little bigger as the images in the EuroSAT dataset are of low resolution (64x64 pixels):

example['image'].resize((200, 200))
Image in a Jupyter notebook

Let's print the corresponding label:

example['label']
5

As you can see, the label field is not an actual string label. By default the ClassLabel fields are encoded into integers for convenience:

dataset["train"].features["label"]
ClassLabel(num_classes=10, names=['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake'], id=None)

Let's create an id2label dictionary to decode them back to strings and see what they are. The inverse label2id will be useful too, when we load the model later.

labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

id2label[5]
'Pasture'
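The same two mappings can also be written as dict comprehensions. So that the sketch runs on its own, it hard-codes the EuroSAT class names from the features output above. In the notebook itself you would use `dataset["train"].features["label"].names` instead.

```python
# EuroSAT class names, in the order reported by the ClassLabel feature
labels = ["AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
          "Pasture", "PermanentCrop", "Residential", "River", "SeaLake"]

label2id = {label: i for i, label in enumerate(labels)}  # name -> integer id
id2label = {i: label for label, i in label2id.items()}   # integer id -> name

print(id2label[5])  # Pasture
```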

Preprocessing the data

Before we can feed these images to our model, we need to preprocess them.

Preprocessing images typically comes down to (1) resizing them to a particular size, and (2) normalizing the color channels (R, G, B) using a per-channel mean and standard deviation. These are referred to as image transformations.
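To make the normalization step concrete, here is a minimal NumPy sketch. It assumes an HWC uint8 image and uses the ImageNet mean and standard deviation that the feature extractor reports for this checkpoint. The helper name `normalize` is illustrative.

```python
import numpy as np

# Per-channel (R, G, B) statistics, as listed in the feature extractor config
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(image_uint8):
    """Rescale an HWC uint8 image to [0, 1], then standardize each channel."""
    image = image_uint8.astype(np.float32) / 255.0
    return (image - mean) / std  # broadcasts over the last (channel) axis


pixel = np.array([[[124, 116, 104]]], dtype=np.uint8)  # a 1x1 RGB image
print(normalize(pixel))
```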

In addition, one typically performs what is called data augmentation during training (like random cropping and flipping) to make the model more robust and achieve higher accuracy. The model sees a slightly different version of each image in every epoch, so data augmentation also effectively increases the diversity of the training data.

We will use tf.image for the image transformations/data augmentation in this tutorial, but note that one can use any other package (like albumentations, imgaug, etc.).

To make sure we (1) resize to the appropriate size (2) use the appropriate image mean and standard deviation for the model architecture we are going to use, we instantiate what is called a feature extractor with the AutoFeatureExtractor.from_pretrained method.

This feature extractor is a minimal preprocessor that can be used to prepare images for inference.

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
feature_extractor
ViTFeatureExtractor { "do_normalize": true, "do_resize": true, "feature_extractor_type": "ViTFeatureExtractor", "image_mean": [ 0.485, 0.456, 0.406 ], "image_std": [ 0.229, 0.224, 0.225 ], "resample": 3, "size": 224 }

The 🤗 Datasets library makes data processing easy. We can write custom functions, which can then be applied on an entire dataset (using either .map() or .set_transform()).

Here we define 2 separate functions, one for training (which includes data augmentation) and one for validation (which only includes resizing, center cropping and normalizing).

import numpy as np
import tensorflow as tf


def normalize_img(img, mean, std):
    mean = tf.constant(mean)
    std = tf.constant(std)
    # Guard against division by zero with Keras' small epsilon
    return (img - mean) / tf.maximum(std, tf.keras.backend.epsilon())


def get_resize_shape(img, size):
    if isinstance(size, tuple):
        return size

    height, width, _ = img.shape
    target_height = int(size * height / width) if height > width else size
    target_width = int(size * width / height) if width > height else size
    return (target_height, target_width)


def get_random_crop_size(img, scale=(0.08, 1.0), ratio=(3/4, 4/3)):
    height, width, channels = img.shape
    img_ratio = width / height
    # Sample scalars (no size=1) so that the int() conversions below receive plain floats
    crop_log_ratio = np.random.uniform(*np.log(ratio))
    crop_ratio = np.exp(crop_log_ratio)
    crop_scale = np.random.uniform(*scale)

    # Make sure the longest side is within the image size
    if crop_ratio < img_ratio:
        crop_height = int(height * crop_scale)
        crop_width = int(crop_height * crop_ratio)
    else:
        crop_width = int(width * crop_scale)
        crop_height = int(crop_width / crop_ratio)
    return (crop_height, crop_width, channels)


def train_transforms(image):
    image = tf.keras.utils.img_to_array(image)

    # Note - this augmentation isn't exactly the same as in the PyTorch
    # examples in:
    # https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb
    # as there isn't a direct RandomResizedCrop equivalent.
    # A custom transform would need to be implemented for this.

    # Randomly select the crop size based on a possible scale
    # and ratio range
    crop_size = get_random_crop_size(image)
    image = tf.image.random_crop(image, size=crop_size)
    image = tf.image.resize(
        image,
        size=(feature_extractor.size, feature_extractor.size),
        method=tf.image.ResizeMethod.BILINEAR
    )
    image = tf.image.random_flip_left_right(image)
    image /= 255
    image = normalize_img(
        image,
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std
    )
    # All image models take channels first format: BCHW
    image = tf.transpose(image, (2, 0, 1))
    return image


def val_transforms(image):
    image = tf.keras.utils.img_to_array(image)
    resize_shape = get_resize_shape(image, feature_extractor.size)
    image = tf.image.resize(
        image,
        size=resize_shape,
        method=tf.image.ResizeMethod.BILINEAR
    )
    image = tf.image.crop_to_bounding_box(
        image,
        offset_height=image.shape[0] // 2 - feature_extractor.size // 2,
        offset_width=image.shape[1] // 2 - feature_extractor.size // 2,
        target_height=feature_extractor.size,
        target_width=feature_extractor.size,
    )
    image /= 255
    image = normalize_img(
        image,
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std
    )
    # All image models take channels first format: BCHW
    image = tf.transpose(image, (2, 0, 1))
    return image


def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch['pixel_values'] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch['pixel_values'] = [
        val_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch
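EuroSAT images are 64x64. In val_transforms, the shortest-edge resize therefore gives 224x224 and the center-crop offsets are zero. The hypothetical helper below repeats the same arithmetic as get_resize_shape and the crop offsets in plain Python (no TensorFlow), which makes it easy to check what happens for non-square images.

```python
def center_crop_box(height, width, size=224):
    """Return the resized (height, width) and the (top, left) offset of the center crop."""
    # Shortest-edge resize, as in get_resize_shape: the shorter side becomes `size`
    new_h = int(size * height / width) if height > width else size
    new_w = int(size * width / height) if width > height else size
    # Offsets of a size x size box centred in the resized image, as in val_transforms
    top = new_h // 2 - size // 2
    left = new_w // 2 - size // 2
    return (new_h, new_w), (top, left)


print(center_crop_box(64, 64))   # ((224, 224), (0, 0))
print(center_crop_box(64, 128))  # ((224, 448), (0, 112))
```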

Let's quickly visualise some example outputs from our processing pipeline.

Part of the processing pipeline rescales the pixel values to [0, 1] and then normalizes them with the model's mean and standard deviation, which produces negative pixel values. To make it easy to visualise and compare the original images and their augmentations, we undo this normalization and rescaling here.

import matplotlib.pyplot as plt
%matplotlib inline

def unnormalize_img(img, mean, std):
    img = (img * std) + mean
    return img

def process_for_plotting(img):
    img = img.numpy()
    img = img.transpose(1, 2, 0)
    img = unnormalize_img(
        img=img,
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std
    )
    img = img * 255
    img = img.astype(int)
    return img

n = 10
fig, ax = plt.subplots(2, n, figsize=(20, 10))

for i in range(n):
    orig_img = dataset['train'][i]['image']
    proc_img = train_transforms(orig_img)

    orig_img = np.array(orig_img.convert("RGB"))

    # In order to plot and easily compare the images,
    # we denormalise and rescale here so that pixel values
    # are between [0, 255] and reorder to be HWC
    proc_img = process_for_plotting(proc_img)

    ax[0][i].imshow(orig_img)
    ax[1][i].imshow(proc_img)
    ax[0][i].axis('off')
    ax[1][i].axis('off')
Image in a Jupyter notebook
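To check that unnormalize_img really inverts the forward pipeline, the NumPy-only sketch below normalizes a random image, moves it to channels-first layout, and then undoes both steps. It uses the same ImageNet statistics as the feature extractor. The random image is made up for illustration.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(4, 4, 3))  # random HWC image with values in [0, 255]

# Forward: rescale, normalize, then move channels first (CHW), as train_transforms does
normalized_chw = ((original / 255 - mean) / std).transpose(2, 0, 1)

# Inverse: back to HWC, undo the normalization, undo the rescaling
restored = (normalized_chw.transpose(1, 2, 0) * std + mean) * 255
restored = np.rint(restored).astype(int)  # round before casting to avoid off-by-one errors
```

The sketch rounds with np.rint before casting. Plain astype(int), as used in process_for_plotting, truncates, so a value like 254.9999 comes back as 254. For plotting, that difference is harmless.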

Next, we can preprocess our dataset by applying these functions. We will use the set_transform functionality, which lets us apply the functions above on the fly (meaning that they are only applied when examples are accessed, rather than to the whole dataset up front).

# split up training into training + validation
splits = dataset["train"].train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
train_ds
Dataset({ features: ['image', 'label'], num_rows: 24300 })
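With test_size=0.1 on 27,000 rows, 2,700 examples go to validation and 24,300 stay in training, which matches the num_rows above. The sketch below shows the idea with a seeded NumPy permutation. `split_indices` is hypothetical, and its rounding rule (test size rounded up) is an assumption of this sketch, not a statement about how train_test_split rounds.

```python
import numpy as np


def split_indices(num_rows, test_size, seed=0):
    """Shuffle row indices and split off a test fraction (simplified sketch)."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_rows)
    n_test = int(np.ceil(test_size * num_rows))  # assumption: round the test size up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(27000, test_size=0.1)
print(len(train_idx), len(test_idx))  # 24300 2700
```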

Let's access an element to see that we've added a "pixel_values" feature:

train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64 at 0x7F33E0357760>, 'label': 3, 'pixel_values': <tf.Tensor: shape=(3, 224, 224), dtype=float32, numpy= array([[[-0.9705454 , -0.9705454 , -0.9705454 , ..., -0.7479236 , -0.7479236 , -0.7479236 ], [-0.9705454 , -0.9705454 , -0.9705454 , ..., -0.7479236 , -0.7479236 , -0.7479236 ], [-0.9705454 , -0.9705454 , -0.9705454 , ..., -0.7479236 , -0.7479236 , -0.7479236 ], ..., [-1.1075436 , -1.1075436 , -1.1075436 , ..., -1.073294 , -1.073294 , -1.073294 ], [-1.1075436 , -1.1075436 , -1.1075436 , ..., -1.073294 , -1.073294 , -1.073294 ], [-1.1075436 , -1.1075436 , -1.1075436 , ..., -1.073294 , -1.073294 , -1.073294 ]], [[-0.582633 , -0.582633 , -0.582633 , ..., -0.495098 , -0.495098 , -0.495098 ], [-0.582633 , -0.582633 , -0.582633 , ..., -0.495098 , -0.495098 , -0.495098 ], [-0.582633 , -0.582633 , -0.582633 , ..., -0.495098 , -0.495098 , -0.495098 ], ..., [-0.687675 , -0.687675 , -0.687675 , ..., -0.687675 , -0.687675 , -0.687675 ], [-0.687675 , -0.687675 , -0.687675 , ..., -0.687675 , -0.687675 , -0.687675 ], [-0.687675 , -0.687675 , -0.687675 , ..., -0.687675 , -0.687675 , -0.687675 ]], [[-0.07895417, -0.07895417, -0.07895417, ..., -0.07895417, -0.07895417, -0.07895417], [-0.07895417, -0.07895417, -0.07895417, ..., -0.07895417, -0.07895417, -0.07895417], [-0.07895417, -0.07895417, -0.07895417, ..., -0.07895417, -0.07895417, -0.07895417], ..., [-0.21838771, -0.21838771, -0.21838771, ..., -0.20095852, -0.20095852, -0.20095852], [-0.21838771, -0.21838771, -0.21838771, ..., -0.20095852, -0.20095852, -0.20095852], [-0.21838771, -0.21838771, -0.21838771, ..., -0.20095852, -0.20095852, -0.20095852]]], dtype=float32)>}
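Note that train_ds still lists only image and label as features, yet train_ds[0] contains pixel_values. This is because set_transform stores the function and runs it only when an item is accessed, passing it a batch (a dict of lists). The toy class below mimics that on-the-fly behaviour in plain Python, using strings as stand-ins for images. `LazyTransformDataset` and `add_pixel_values` are hypothetical and much simpler than the real Dataset.

```python
class LazyTransformDataset:
    """Toy stand-in for set_transform's on-the-fly behaviour (not the Datasets implementation)."""

    def __init__(self, rows):
        self.rows = rows        # list of dicts holding the raw columns
        self.transform = None
        self.calls = 0          # how many times the transform has run

    def set_transform(self, fn):
        # Only remember the function; nothing is computed at this point
        self.transform = fn

    def __getitem__(self, idx):
        # Wrap the row as a batch of one, because the transform works on batches
        batch = {key: [value] for key, value in self.rows[idx].items()}
        if self.transform is not None:
            batch = self.transform(batch)
            self.calls += 1
        # Unwrap the batch of one back into a single example
        return {key: values[0] for key, values in batch.items()}


def add_pixel_values(batch):
    # Stand-in for preprocess_train: add a derived column for every item in the batch
    batch["pixel_values"] = [len(image) for image in batch["image"]]
    return batch


ds = LazyTransformDataset([{"image": "abc", "label": 0}, {"image": "defg", "label": 1}])
ds.set_transform(add_pixel_values)
print(ds.calls)  # 0: setting the transform did not process anything
print(ds[1])     # {'image': 'defg', 'label': 1, 'pixel_values': 4}
print(ds.calls)  # 1
```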

Training the model

Now that our data is ready, we can download the pretrained model and fine-tune it. For classification we use the TFAutoModelForImageClassification class. Calling the from_pretrained method on it will download and cache the weights for us. As the label ids and the number of labels are dataset dependent, we pass label2id and id2label alongside the model_checkpoint here. This ensures that a custom classification head is created (with a custom number of output neurons).

NOTE: in case you're planning to fine-tune an already fine-tuned checkpoint, like facebook/convnext-tiny-224 (which has already been fine-tuned on ImageNet-1k), then you need to provide the additional argument ignore_mismatched_sizes=True to the from_pretrained method. This will make sure the output head (with 1000 output neurons) is thrown away and replaced by a new, randomly initialized classification head that includes a custom number of output neurons. You don't need to specify this argument in case the pre-trained model doesn't include a head. The microsoft/swin-tiny-patch4-window7-224 checkpoint used here also comes with a 1000-class head, which is why the cell below passes the argument.

from transformers import TFAutoModelForImageClassification

model = TFAutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
All model checkpoint layers were used when initializing TFSwinForImageClassification. Some weights of TFSwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match: - classifier/kernel:0: found shape (768, 1000) in the checkpoint and (768, 10) in the model instantiated - classifier/bias:0: found shape (1000,) in the checkpoint and (10,) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The warning is telling us we are throwing away some weights (the weights and bias of the classifier layer) and randomly initializing some others (the weights and bias of a new classifier layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.
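Conceptually, ignore_mismatched_sizes=True copies every checkpoint weight whose shape matches the new model and keeps the fresh initialization for the rest. The NumPy sketch below illustrates only that idea. The helper `load_matching_weights` and the layer name `encoder/dense` are hypothetical; the classifier shapes match the warning above.

```python
import numpy as np


def load_matching_weights(model_weights, checkpoint_weights):
    """Copy checkpoint arrays whose shapes match; keep the fresh init for the rest.

    Hypothetical helper illustrating the idea behind ignore_mismatched_sizes=True.
    Returns the names of the checkpoint weights that were skipped.
    """
    skipped = []
    for name, value in checkpoint_weights.items():
        if name in model_weights and model_weights[name].shape == value.shape:
            model_weights[name] = value.copy()
        else:
            skipped.append(name)
    return skipped


rng = np.random.default_rng(0)
ckpt_weights = {
    "encoder/dense": rng.normal(size=(768, 768)),
    "classifier/kernel": rng.normal(size=(768, 1000)),  # 1000-class head in the checkpoint
    "classifier/bias": np.zeros(1000),
}
new_weights = {
    "encoder/dense": np.zeros((768, 768)),
    "classifier/kernel": rng.normal(scale=0.02, size=(768, 10)),  # new 10-class head
    "classifier/bias": np.zeros(10),
}
skipped = load_matching_weights(new_weights, ckpt_weights)
print(skipped)  # ['classifier/kernel', 'classifier/bias']
```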

learning_rate = 5e-5
weight_decay = 0.01
epochs = 3
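The number of optimizer updates follows from the training-set size, the batch size, and the number of epochs. Here is a small arithmetic sketch. It assumes the 24,300-example training split from above, a batch size of 32, 3 epochs, and that the final partial batch is kept rather than dropped.

```python
import math

num_train_examples = 24300  # size of train_ds after the 90/10 split
train_batch_size = 32
num_epochs = 3

# The last batch of each epoch is partial (24300 is not a multiple of 32), hence ceil
steps_per_epoch = math.ceil(num_train_examples / train_batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 760 2280
```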

Now we initialize our optimizer.

from transformers import AdamWeightDecay

optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)
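AdamWeightDecay is built around decoupled weight decay. The weights are shrunk directly by learning_rate * weight_decay_rate * weight, instead of an L2 penalty being added to the gradient. The NumPy sketch below shows one such step for a single tensor. It is a simplified illustration (no parameter exclusion, no learning-rate schedule), not the library's code, and `adamw_step` is a hypothetical name.

```python
import numpy as np


def adamw_step(param, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam step with decoupled weight decay (simplified sketch, NumPy only)."""
    # Decoupled decay: shrink the weights directly, independent of the gradient
    param = param - lr * weight_decay * param
    # Standard Adam first/second moment estimates with bias correction (t starts at 1)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v
```

With a zero gradient, the Adam part of the update vanishes and only the decay term remains, which is a quick way to see that the decay does not depend on the gradient.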

Note that most models on the Hub compute loss internally, so we actually don't have to specify anything there! If you leave the loss argument out of compile(), the model will use its internally computed loss as its loss value.

This is an unusual quirk of TensorFlow models in 🤗 Transformers, so it's worth elaborating on in a little more detail. All 🤗 Transformers models are capable of computing an appropriate loss for their task internally (for example, a CausalLM model will use a cross-entropy loss). To do this, the labels must be provided in the input dict (or equivalently, in the columns argument to to_tf_dataset()), so that they are visible to the model during the forward pass.

This is quite different from the standard Keras way of handling losses, where labels are passed separately and not visible to the main body of the model, and loss is handled by a function that the user passes to compile(), which uses the model outputs and the label to compute a loss value.

The approach we take is that if you do not pass a loss to compile(), the model will assume you want the internal loss. If you are doing this, you should make sure that the labels column(s) are included in the input dict or in the columns argument to to_tf_dataset.

If you want to use your own loss, that is of course possible too! If you do this, you should make sure your labels column(s) are passed like normal labels, either as the second argument to model.fit(), or in the label_cols argument to to_tf_dataset.
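For a single-label classification head, the internal loss is, as far as we are aware, a sparse categorical cross-entropy computed on the logits. If you would rather pass your own loss, tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) is the usual Keras choice. The NumPy sketch below spells out that computation. The function name is illustrative.

```python
import numpy as np


def sparse_cross_entropy_from_logits(logits, labels):
    """Mean cross-entropy between integer labels and unnormalized logits."""
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for each example
    return -log_probs[np.arange(len(labels)), labels].mean()
```

With all-zero logits over 10 classes, every class gets probability 1/10, so the loss equals log(10), about 2.30. This is roughly where a freshly initialized head starts.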

model.compile(optimizer=optimizer)
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.

We need to convert our datasets to a format Keras understands. The easiest way to do this is with the to_tf_dataset() method. Note that our data collators are designed to work for multiple frameworks, so ensure you set the return_tensors='np' argument to get NumPy arrays out - you don't want to accidentally get a load of torch.Tensor objects in the middle of your nice TF code! You could also use return_tensors='tf' to get TensorFlow tensors, but our to_tf_dataset pipeline actually uses a NumPy loader internally, which is wrapped at the end with a tf.data.Dataset. As a result, np is usually more reliable and performant when you're using it!

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator(return_tensors="np")

train_set = train_ds.to_tf_dataset(
    columns=["pixel_values", "label"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator
)
val_set = val_ds.to_tf_dataset(
    columns=["pixel_values", "label"],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator
)
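The collator's job here is to turn a list of per-example dicts into a dict of stacked NumPy arrays. The simplified, hypothetical `simple_collate` below sketches that behaviour, including the renaming of label to labels that DefaultDataCollator performs (which is why the dataset's element spec shows a labels key). It ignores the other special cases the real collator handles.

```python
import numpy as np


def simple_collate(features):
    """Stack a list of example dicts into a dict of NumPy arrays.

    Hypothetical stand-in for DefaultDataCollator(return_tensors="np"):
    the "label" key is renamed to "labels", the name the model expects.
    """
    batch = {}
    for key in features[0]:
        values = np.stack([np.asarray(f[key]) for f in features])
        batch["labels" if key == "label" else key] = values
    return batch


examples = [
    {"pixel_values": np.zeros((3, 224, 224), dtype=np.float32), "label": 3},
    {"pixel_values": np.ones((3, 224, 224), dtype=np.float32), "label": 7},
]
toy_batch = simple_collate(examples)
print({k: v.shape for k, v in toy_batch.items()})
# {'pixel_values': (2, 3, 224, 224), 'labels': (2,)}
```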

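train_set is now a tf.data.Dataset type object. It contains two elements, labels and pixel_values, but not image. The pixel_values element was added by preprocess_train. The image column is left out because it isn't listed in the columns argument, and the data collator renames the label column to labels.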

train_set
<PrefetchDataset element_spec={'labels': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'pixel_values': TensorSpec(shape=(None, 3, 224, 224), dtype=tf.float32, name=None)}>
batch = next(iter(train_set))
batch
{'labels': <tf.Tensor: shape=(32,), dtype=int64, numpy= array([9, 0, 6, 6, 2, 7, 5, 8, 9, 8, 5, 6, 2, 3, 1, 2, 1, 4, 6, 3, 3, 3, 9, 5, 5, 4, 7, 3, 9, 2, 4, 7])>, 'pixel_values': <tf.Tensor: shape=(32, 3, 224, 224), dtype=float32, numpy= array([[[[-1.3130406 , -1.3130406 , -1.3130406 , ..., -1.2274169 , -1.2274169 , -1.2274169 ], [-1.3130406 , -1.3130406 , -1.3130406 , ..., -1.2274169 , -1.2274169 , -1.2274169 ], [-1.3130406 , -1.3130406 , -1.3130406 , ..., -1.2274169 , -1.2274169 , -1.2274169 ], ..., [-1.2445416 , -1.2445416 , -1.2445416 , ..., -1.3301654 , -1.3301654 , -1.3301654 ], [-1.2445416 , -1.2445416 , -1.2445416 , ..., -1.3301654 , -1.3301654 , -1.3301654 ], [-1.2445416 , -1.2445416 , -1.2445416 , ..., -1.3301654 , -1.3301654 , -1.3301654 ]], [[-0.757703 , -0.757703 , -0.757703 , ..., -0.705182 , -0.705182 , -0.705182 ], [-0.757703 , -0.757703 , -0.757703 , ..., -0.705182 , -0.705182 , -0.705182 ], [-0.757703 , -0.757703 , -0.757703 , ..., -0.705182 , -0.705182 , -0.705182 ], ..., [-0.722689 , -0.722689 , -0.722689 , ..., -0.757703 , -0.757703 , -0.757703 ], [-0.722689 , -0.722689 , -0.722689 , ..., -0.757703 , -0.757703 , -0.757703 ], [-0.722689 , -0.722689 , -0.722689 , ..., -0.757703 , -0.757703 , -0.757703 ]], [[-0.30553368, -0.30553368, -0.30553368, ..., -0.2706753 , -0.2706753 , -0.2706753 ], [-0.30553368, -0.30553368, -0.30553368, ..., -0.2706753 , -0.2706753 , -0.2706753 ], [-0.30553368, -0.30553368, -0.30553368, ..., -0.2706753 , -0.2706753 , -0.2706753 ], ..., [-0.28810447, -0.28810447, -0.28810447, ..., -0.34039208, -0.34039208, -0.34039208], [-0.28810447, -0.28810447, -0.28810447, ..., -0.34039208, -0.34039208, -0.34039208], [-0.28810447, -0.28810447, -0.28810447, ..., -0.34039208, -0.34039208, -0.34039208]]], [[[-1.1246684 , -1.1246684 , -1.1398053 , ..., -0.7922647 , -0.78217316, -0.78217316], [-1.1246684 , -1.1246684 , -1.1398053 , ..., -0.7922647 , -0.78217316, -0.78217316], [-1.118858 , -1.118858 , -1.1348511 , ..., -0.7848334 , 
-0.77345794, -0.77345794], ..., [-1.0128986 , -1.0128986 , -0.93421745, ..., 0.08415603, 0.07406463, 0.07406463], [-1.0390445 , -1.0390445 , -0.9633593 , ..., 0.08415603, 0.07406463, 0.07406463], [-1.0390445 , -1.0390445 , -0.9633593 , ..., 0.08415603, 0.07406463, 0.07406463]], [[-0.617647 , -0.617647 , -0.6305429 , ..., -0.32776552, -0.32002798, -0.32002798], [-0.617647 , -0.617647 , -0.6305429 , ..., -0.32776552, -0.32002798, -0.32002798], [-0.61170715, -0.61170715, -0.6250405 , ..., -0.32567072, -0.31705812, -0.31705812], ..., [-0.6962723 , -0.6962723 , -0.65150535, ..., -0.17739274, -0.17997196, -0.17997196], [-0.705182 , -0.705182 , -0.66391563, ..., -0.17739274, -0.17997196, -0.17997196], [-0.705182 , -0.705182 , -0.66391563, ..., -0.17739274, -0.17997196, -0.17997196]], [[-0.2532461 , -0.2532461 , -0.26094913, ..., -0.08152194, -0.07895417, -0.07895417], [-0.2532461 , -0.2532461 , -0.26094913, ..., -0.08152194, -0.07895417, -0.07895417], [-0.2532461 , -0.2532461 , -0.26182044, ..., -0.0828286 , -0.07895417, -0.07895417], ..., [-0.33152196, -0.33152196, -0.29509267, ..., 0.09834104, 0.09533776, 0.09533776], [-0.34039208, -0.34039208, -0.30701214, ..., 0.09790553, 0.09533776, 0.09533776], [-0.34039208, -0.34039208, -0.30701214, ..., 0.09790553, 0.09533776, 0.09533776]]], [[[-0.19993155, -0.19993155, -0.19993155, ..., 0.5193082 , 0.5193082 , 0.5193082 ], [-0.19993155, -0.19993155, -0.19993155, ..., 0.5193082 , 0.5193082 , 0.5193082 ], [-0.19993155, -0.19993155, -0.19993155, ..., 0.5193082 , 0.5193082 , 0.5193082 ], ..., [ 0.7761795 , 0.7761795 , 0.7761795 , ..., 0.15968838, 0.15968838, 0.15968838], [ 0.7761795 , 0.7761795 , 0.7761795 , ..., 0.15968838, 0.15968838, 0.15968838], [ 0.7761795 , 0.7761795 , 0.7761795 , ..., 0.15968838, 0.15968838, 0.15968838]], [[-0.05742295, -0.05742295, -0.05742295, ..., 0.62535024, 0.62535024, 0.62535024], [-0.05742295, -0.05742295, -0.05742295, ..., 0.62535024, 0.62535024, 0.62535024], [-0.05742295, -0.05742295, -0.05742295, 
..., 0.62535024, 0.62535024, 0.62535024], ..., [ 0.78291327, 0.78291327, 0.78291327, ..., 0.3802522 , 0.3802522 , 0.3802522 ], [ 0.78291327, 0.78291327, 0.78291327, ..., 0.3802522 , 0.3802522 , 0.3802522 ], [ 0.78291327, 0.78291327, 0.78291327, ..., 0.3802522 , 0.3802522 , 0.3802522 ]], [[ 0.23477131, 0.23477131, 0.23477131, ..., 0.6356429 , 0.6356429 , 0.6356429 ], [ 0.23477131, 0.23477131, 0.23477131, ..., 0.6356429 , 0.6356429 , 0.6356429 ], [ 0.23477131, 0.23477131, 0.23477131, ..., 0.6356429 , 0.6356429 , 0.6356429 ], ..., [ 0.8099348 , 0.8099348 , 0.8099348 , ..., 0.47878015, 0.47878015, 0.47878015], [ 0.8099348 , 0.8099348 , 0.8099348 , ..., 0.47878015, 0.47878015, 0.47878015], [ 0.8099348 , 0.8099348 , 0.8099348 , ..., 0.47878015, 0.47878015, 0.47878015]]], ..., [[[-0.25130582, -0.25130582, -0.25130582, ..., -0.45680285, -0.45680285, -0.45680285], [-0.25130582, -0.25130582, -0.25130582, ..., -0.45680285, -0.45680285, -0.45680285], [-0.25130582, -0.25130582, -0.25130582, ..., -0.45680285, -0.45680285, -0.45680285], ..., [ 0.09118938, 0.09118938, 0.09118938, ..., -0.37117907, -0.37117907, -0.37117907], [ 0.09118938, 0.09118938, 0.09118938, ..., -0.37117907, -0.37117907, -0.37117907], [ 0.09118938, 0.09118938, 0.09118938, ..., -0.37117907, -0.37117907, -0.37117907]], [[-0.12745096, -0.12745096, -0.12745096, ..., -0.24999997, -0.24999997, -0.24999997], [-0.12745096, -0.12745096, -0.12745096, ..., -0.24999997, -0.24999997, -0.24999997], [-0.12745096, -0.12745096, -0.12745096, ..., -0.24999997, -0.24999997, -0.24999997], ..., [ 0.04761905, 0.04761905, 0.04761905, ..., -0.39005598, -0.39005598, -0.39005598], [ 0.04761905, 0.04761905, 0.04761905, ..., -0.39005598, -0.39005598, -0.39005598], [ 0.04761905, 0.04761905, 0.04761905, ..., -0.39005598, -0.39005598, -0.39005598]], [[ 0.09533776, 0.09533776, 0.09533776, ..., -0.09638336, -0.09638336, -0.09638336], [ 0.09533776, 0.09533776, 0.09533776, ..., -0.09638336, -0.09638336, -0.09638336], [ 0.09533776, 0.09533776, 
0.09533776, ..., -0.09638336, -0.09638336, -0.09638336], ..., [ 0.11276696, 0.11276696, 0.11276696, ..., -0.20095852, -0.20095852, -0.20095852], [ 0.11276696, 0.11276696, 0.11276696, ..., -0.20095852, -0.20095852, -0.20095852], [ 0.11276696, 0.11276696, 0.11276696, ..., -0.20095852, -0.20095852, -0.20095852]]], [[[-1.0904187 , -1.0904187 , -1.0904187 , ..., -1.1589178 , -1.1589178 , -1.1589178 ], [-1.0904187 , -1.0904187 , -1.0904187 , ..., -1.1589178 , -1.1589178 , -1.1589178 ], [-1.0904187 , -1.0904187 , -1.0904187 , ..., -1.1589178 , -1.1589178 , -1.1589178 ], ..., [-0.23418105, -0.23418105, -0.23418105, ..., 2.2489083 , 2.2489083 , 2.2489083 ], [-0.23418105, -0.23418105, -0.23418105, ..., 2.2489083 , 2.2489083 , 2.2489083 ], [-0.23418105, -0.23418105, -0.23418105, ..., 2.2489083 , 2.2489083 , 2.2489083 ]], [[-0.460084 , -0.460084 , -0.460084 , ..., -0.42507 , -0.42507 , -0.42507 ], [-0.460084 , -0.460084 , -0.460084 , ..., -0.42507 , -0.42507 , -0.42507 ], [-0.460084 , -0.460084 , -0.460084 , ..., -0.42507 , -0.42507 , -0.42507 ], ..., [ 0.06512605, 0.06512605, 0.06512605, ..., 2.4285715 , 2.4285715 , 2.4285715 ], [ 0.06512605, 0.06512605, 0.06512605, ..., 2.4285715 , 2.4285715 , 2.4285715 ], [ 0.06512605, 0.06512605, 0.06512605, ..., 2.4285715 , 2.4285715 , 2.4285715 ]], [[-0.2706753 , -0.2706753 , -0.2706753 , ..., -0.23581691, -0.23581691, -0.23581691], [-0.2706753 , -0.2706753 , -0.2706753 , ..., -0.23581691, -0.23581691, -0.23581691], [-0.2706753 , -0.2706753 , -0.2706753 , ..., -0.23581691, -0.23581691, -0.23581691], ..., [ 0.2696297 , 0.2696297 , 0.2696297 , ..., 2.64 , 2.64 , 2.64 ], [ 0.2696297 , 0.2696297 , 0.2696297 , ..., 2.64 , 2.64 , 2.64 ], [ 0.2696297 , 0.2696297 , 0.2696297 , ..., 2.64 , 2.64 , 2.64 ]]], [[[-0.06293353, -0.06293353, -0.06293353, ..., 0.57068247, 0.57068247, 0.57068247], [-0.06293353, -0.06293353, -0.06293353, ..., 0.57068247, 0.57068247, 0.57068247], [-0.06293353, -0.06293353, -0.06293353, ..., 0.57068247, 0.57068247, 
0.57068247], ..., [-0.18280679, -0.18280679, -0.18280679, ..., 0.34806067, 0.34806067, 0.34806067], [-0.18280679, -0.18280679, -0.18280679, ..., 0.34806067, 0.34806067, 0.34806067], [-0.18280679, -0.18280679, -0.18280679, ..., 0.34806067, 0.34806067, 0.34806067]], [[-0.02240895, -0.02240895, -0.02240895, ..., 0.62535024, 0.62535024, 0.62535024], [-0.02240895, -0.02240895, -0.02240895, ..., 0.62535024, 0.62535024, 0.62535024], [-0.02240895, -0.02240895, -0.02240895, ..., 0.62535024, 0.62535024, 0.62535024], ..., [-0.02240895, -0.02240895, -0.02240895, ..., 0.4677872 , 0.4677872 , 0.4677872 ], [-0.02240895, -0.02240895, -0.02240895, ..., 0.4677872 , 0.4677872 , 0.4677872 ], [-0.02240895, -0.02240895, -0.02240895, ..., 0.4677872 , 0.4677872 , 0.4677872 ]], [[ 0.32191727, 0.32191727, 0.32191727, ..., 0.91451 , 0.91451 , 0.91451 ], [ 0.32191727, 0.32191727, 0.32191727, ..., 0.91451 , 0.91451 , 0.91451 ], [ 0.32191727, 0.32191727, 0.32191727, ..., 0.91451 , 0.91451 , 0.91451 ], ..., [ 0.42649257, 0.42649257, 0.42649257, ..., 0.827364 , 0.827364 , 0.827364 ], [ 0.42649257, 0.42649257, 0.42649257, ..., 0.827364 , 0.827364 , 0.827364 ], [ 0.42649257, 0.42649257, 0.42649257, ..., 0.827364 , 0.827364 , 0.827364 ]]]], dtype=float32)>}

The last thing to define is how to compute the metrics from the predictions. We need to define a function for this, which will just use the metric we loaded earlier. The only preprocessing we have to do is to take the argmax of our predicted logits.

In addition, let's wrap this metric computation function in a KerasMetricCallback. This callback will compute the metric on the validation set at the end of each epoch, print it, and log it for other callbacks like TensorBoard and EarlyStopping.

Why do it this way, though, and not just use a straightforward Keras Metric object? This is a good question - on this task, metrics such as Accuracy are very straightforward, and it would probably make more sense to just use a Keras metric for those instead. However, we want to demonstrate the use of KerasMetricCallback here, because it can handle any arbitrary Python function for the metric computation.

How do we actually use KerasMetricCallback? We simply define a function that computes metrics given a tuple of NumPy arrays of predictions and labels, then we pass that, along with the validation set to compute metrics on, to the callback:

import numpy as np
from transformers.keras_callbacks import KerasMetricCallback

# the compute_metrics function takes a Tuple as input:
# first element is the logits of the model as Numpy arrays,
# second element is the ground-truth labels as Numpy arrays.
def compute_metrics(eval_predictions):
    predictions = np.argmax(eval_predictions[0], axis=1)
    metric_val = metric.compute(predictions=predictions, references=eval_predictions[1])
    # prefix keys with "val_" so they appear as e.g. val_accuracy in the Keras logs
    return {"val_" + k: v for k, v in metric_val.items()}

metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics,
    eval_dataset=val_set,
    batch_size=batch_size,
    label_cols=['labels']
)

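Now we can train our model. Besides the metric callback, we add a TensorBoard callback to log training curves and a PushToHubCallback to sync the model with the Hub. Syncing allows us to resume training from other machines and even test the model's inference quality midway through training! Because hub_model_id below contains no namespace, the repository is created under the Hub account you are logged in with; to push to an organization instead, pass a full "org-name/model-name" id. If you don't want to push to the Hub, simply remove push_to_hub_callback from the callbacks list passed to fit().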

from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard

tensorboard_callback = TensorBoard(log_dir="./ic_from_scratch_model_save/logs")

model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-eurosat"

push_to_hub_callback = PushToHubCallback(
    output_dir="./ic_from_scratch_model_save",
    hub_model_id=push_to_hub_model_id,
    tokenizer=feature_extractor
)
Cloning https://huggingface.co/amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat into local empty directory.
callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]
model.fit(
    train_set,
    validation_data=val_set,
    callbacks=callbacks,
    epochs=epochs,
    # no batch_size argument: train_set is a tf.data.Dataset that is already batched
)
Epoch 1/3
760/760 [==============================] - 616s 786ms/step - loss: 0.6595 - val_loss: 0.1077 - val_accuracy: 0.9622
Epoch 2/3
760/760 [==============================] - 586s 771ms/step - loss: 0.4503 - val_loss: 0.0813 - val_accuracy: 0.9707
Epoch 3/3
760/760 [==============================] - 584s 769ms/step - loss: 0.4117 - val_loss: 0.0491 - val_accuracy: 0.9822
Several commits (2) will be pushed upstream. The progress bars may be unreliable.
To https://huggingface.co/amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat b5cdccd..94f8012 main -> main
<keras.callbacks.History at 0x7f3384135910>

Once the training is completed, we can evaluate our model and get its loss on the validation set like this:

eval_loss = model.evaluate(val_set)
eval_loss
85/85 [==============================] - 24s 286ms/step - loss: 0.0491
0.04911531135439873
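We assume the loss reported here is the mean sparse categorical cross-entropy that the model computes from its logits and the integer labels. A minimal NumPy sketch of that quantity on made-up numbers, using a numerically stable log-softmax:

```python
import numpy as np

def sparse_categorical_cross_entropy(logits, labels):
    """Mean cross-entropy between integer labels and raw (unnormalized) logits."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability assigned to the true class of each example.
    true_class_log_probs = log_probs[np.arange(len(labels)), labels]
    return float(-true_class_log_probs.mean())

# Uniform logits over EuroSAT's 10 classes give a loss of log(10), about 2.3026.
uniform_logits = np.zeros((4, 10))
print(sparse_categorical_cross_entropy(uniform_logits, np.array([0, 3, 5, 9])))
```

Read this way, a validation loss of about 0.049 corresponds to a geometric-mean probability of roughly exp(-0.049) ≈ 0.95 on the correct class.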

Alternatively, we can get the predictions from the model and compute metrics with the datasets.Metric object.

for batch in val_set:
    # verbose=0 suppresses the per-batch progress bar
    predictions = model.predict(batch, verbose=0)
    predicted_labels = np.argmax(predictions.logits, -1)
    metric.add_batch(predictions=predicted_labels, references=batch['labels'])

metric.compute()
{'accuracy': 0.9822222222222222}
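The add_batch/compute pattern above accumulates results batch by batch and computes the metric once at the end. The hypothetical StreamingAccuracy class below mimics that pattern in plain Python to show why the final number is the accuracy over the whole validation set; it is an illustration, not the datasets implementation.

```python
import numpy as np

class StreamingAccuracy:
    """Hypothetical accumulator mimicking the add_batch()/compute() pattern."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        predictions = np.asarray(predictions)
        references = np.asarray(references)
        # This sketch keeps only running counts, not the individual predictions.
        self.correct += int(np.sum(predictions == references))
        self.total += int(references.size)

    def compute(self):
        result = {"accuracy": self.correct / self.total}
        # Reset so the same object can be reused for another evaluation pass.
        self.correct, self.total = 0, 0
        return result

sketch = StreamingAccuracy()
sketch.add_batch(predictions=[0, 1, 2], references=[0, 1, 1])  # 2 of 3 correct
sketch.add_batch(predictions=[3, 3], references=[3, 0])        # 1 of 2 correct
print(sketch.compute())  # {'accuracy': 0.6}
```

Accumulating counts matters when batches differ in size: averaging the per-batch accuracies 2/3 and 1/2 would give about 0.583 instead of 0.6.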

You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier "your-username/the-name-you-picked", for instance:

from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("amyeroberts/my-awesome-model")
model = TFAutoModelForImageClassification.from_pretrained("amyeroberts/my-awesome-model")
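Because hub_model_id in the PushToHubCallback above has no namespace, the repository lives under your own account, and the full identifier is namespace/repository. A small pure-Python sketch of how that identifier is built from model_checkpoint, where "your-username" is a placeholder you would replace:

```python
# Checkpoint we fine-tuned from (same value as at the top of the notebook).
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"

# Drop the original namespace ("microsoft") and keep only the model name.
model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-eurosat"

# Placeholder namespace: replace with your own Hub username or organization.
username = "your-username"
repo_id = f"{username}/{push_to_hub_model_id}"
print(repo_id)  # your-username/swin-tiny-patch4-window7-224-finetuned-eurosat
```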

Inference

Let's say you have a new image on which you'd like to make a prediction. Let's load a satellite image of a forest that is not part of the EuroSAT dataset and see how the model does.

from PIL import Image
import requests

url = 'https://huggingface.co/nielsr/convnext-tiny-finetuned-eurostat/resolve/main/forest.png'
image = Image.open(requests.get(url, stream=True).raw)
image
Image in a Jupyter notebook

We'll load the feature extractor and model from the Hub. Here we use the Auto Classes, which automatically load the appropriate classes based on the config.json and preprocessor_config.json files of the repo on the Hub:

from transformers import TFAutoModelForImageClassification, AutoFeatureExtractor

repo_name = "amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat"

feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)
model = TFAutoModelForImageClassification.from_pretrained(repo_name)
All model checkpoint layers were used when initializing TFSwinForImageClassification. All the layers of TFSwinForImageClassification were initialized from the model checkpoint at amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFSwinForImageClassification for predictions without further training.
# prepare image for the model
encoding = feature_extractor(image.convert("RGB"), return_tensors="tf")
print(encoding.pixel_values.shape)
(1, 3, 224, 224)
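The shape (1, 3, 224, 224) is a batch of one image in channels-first layout. As a rough NumPy sketch of the normalization step, we assume the image has already been resized to 224×224, that pixel values are rescaled to [0, 1] first, and that the image_mean and image_std values shown for this checkpoint's feature extractor in the Pipeline API section are used. The helper to_pixel_values is hypothetical, not part of transformers.

```python
import numpy as np

# ImageNet statistics, as listed in the feature extractor config (image_mean / image_std).
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_pixel_values(image_hwc_uint8):
    """Hypothetical helper: uint8 HxWx3 image -> float32 array of shape (1, 3, H, W)."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # rescale to [0, 1]
    x = (x - IMAGE_MEAN) / IMAGE_STD                 # per-channel normalization, broadcast over H and W
    x = np.transpose(x, (2, 0, 1))                   # HWC -> CHW (channels first)
    return x[np.newaxis, ...]                        # add the batch dimension

# Random stand-in for the already-resized satellite image.
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(224, 224, 3)).astype(np.uint8)
pixel_values = to_pixel_values(fake_image)
print(pixel_values.shape)  # (1, 3, 224, 224)
```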
outputs = model(encoding)
logits = outputs.logits
import tensorflow as tf

predicted_class_idx = tf.math.argmax(logits, -1).numpy()[0]
print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: Forest

Looks like our model got it correct!

Pipeline API

An alternative way to quickly perform inference with any model on the Hub is the Pipeline API, which abstracts away all the steps we did manually above. A pipeline performs the preprocessing, forward pass, and postprocessing in a single object.

Note that the feature extractor configuration is pulled from the specified repo and used to build the pipeline's feature_extractor.

Let's showcase this for our trained model:

from transformers import pipeline

pipe = pipeline(
    "image-classification",
    "amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat",
    framework="tf"
)
All model checkpoint layers were used when initializing TFSwinForImageClassification. All the layers of TFSwinForImageClassification were initialized from the model checkpoint at amyeroberts/swin-tiny-patch4-window7-224-finetuned-eurosat. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFSwinForImageClassification for predictions without further training.
pipe.feature_extractor
ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [0.485, 0.456, 0.406],
  "image_std": [0.229, 0.224, 0.225],
  "resample": 3,
  "size": 224
}
pipe(image)
[{'score': 0.6285105347633362, 'label': 'Forest'}, {'score': 0.24625204503536224, 'label': 'HerbaceousVegetation'}, {'score': 0.10910254716873169, 'label': 'Residential'}, {'score': 0.005121162626892328, 'label': 'PermanentCrop'}, {'score': 0.004805787466466427, 'label': 'Highway'}]

As we can see, the pipeline returns not only the most probable class label but the top 5 labels with their corresponding scores. Pipelines also work with model and feature extractor objects that are already loaded:

pipe = pipeline("image-classification", model=model, feature_extractor=feature_extractor)
pipe(image)
[{'score': 0.6285105347633362, 'label': 'Forest'}, {'score': 0.24625204503536224, 'label': 'HerbaceousVegetation'}, {'score': 0.10910254716873169, 'label': 'Residential'}, {'score': 0.005121162626892328, 'label': 'PermanentCrop'}, {'score': 0.004805787466466427, 'label': 'Highway'}]
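The pipeline's postprocessing can be approximated as a softmax over the logits followed by picking the k highest scores. A minimal NumPy sketch, where the logits and the shortened label mapping are made up for illustration (for the real model, the mapping would come from model.config.id2label):

```python
import numpy as np

def top_k_predictions(logits, id2label, k=5):
    """Softmax over one image's logits, then the k most probable labels."""
    shifted = logits - logits.max()  # numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top_ids = np.argsort(probs)[::-1][:k]  # class ids sorted by descending probability
    return [{"score": float(probs[i]), "label": id2label[int(i)]} for i in top_ids]

# Hypothetical logits and label mapping, for illustration only.
example_id2label = {0: "AnnualCrop", 1: "Forest", 2: "Highway", 3: "Residential"}
example_logits = np.array([0.5, 3.0, -1.0, 1.0])
for pred in top_k_predictions(example_logits, example_id2label, k=2):
    print(pred)
```

Since softmax scores over all classes sum to 1, the top-5 scores in the real output above add up to slightly less than 1, the remainder being spread over the other EuroSAT classes.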