Image segmentation with a U-Net-like architecture
Author: fchollet
Date created: 2019/03/20
Last modified: 2020/04/20
Description: Image segmentation model trained from scratch on the Oxford Pets dataset.
Download the data
In [0]:
# The images and trimap annotations are available from the original VGG URLs
# and from the thor.robots.ox.ac.uk mirror.
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

# Alternatively:
!curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
!curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz

# Extract the archives
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz
Prepare paths of input images and target segmentation masks
In [0]:
import os

input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 32

input_img_paths = sorted(
    [
        os.path.join(input_dir, fname)
        for fname in os.listdir(input_dir)
        if fname.endswith(".jpg")
    ]
)
target_img_paths = sorted(
    [
        os.path.join(target_dir, fname)
        for fname in os.listdir(target_dir)
        if fname.endswith(".png") and not fname.startswith(".")
    ]
)

print("Number of samples:", len(input_img_paths))

for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)
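Both lists are sorted so that each image lines up with its trimap. If you want to verify that pairing explicitly (an optional check, not part of the original example), comparing base filenames is enough:

In [0]:
# Optional sanity check: every image should have a trimap with the same base name.
for img_path, mask_path in zip(input_img_paths, target_img_paths):
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    mask_name = os.path.splitext(os.path.basename(mask_path))[0]
    assert img_name == mask_name, (img_path, mask_path)
print("All image/trimap filenames are paired.")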
What does one input image and corresponding segmentation mask look like?
In [0]:
from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps

# Display input image #10 (index 9)
display(Image(filename=input_img_paths[9]))

# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)
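The trimap is stored as a PNG whose pixel values are small integer class IDs rather than colors, which is why the auto-contrast step is needed to make it visible. A quick way to confirm the raw label values (an optional check, not in the original notebook):

In [0]:
import numpy as np

# In the Oxford-IIIT Pet trimaps, 1 = pet, 2 = background, 3 = border/undefined.
mask = np.array(load_img(target_img_paths[9]))
print("Unique label values in this trimap:", np.unique(mask))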
Prepare dataset to load & vectorize batches of data
In [0]:
import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io


def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Returns a TF Dataset."""

    def load_img_masks(input_img_path, target_img_path):
        input_img = tf_io.read_file(input_img_path)
        input_img = tf_io.decode_png(input_img, channels=3)
        input_img = tf_image.resize(input_img, img_size)
        input_img = tf_image.convert_image_dtype(input_img, "float32")

        target_img = tf_io.read_file(target_img_path)
        target_img = tf_io.decode_png(target_img, channels=1)
        target_img = tf_image.resize(target_img, img_size, method="nearest")
        target_img = tf_image.convert_image_dtype(target_img, "uint8")

        # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
        target_img -= 1
        return input_img, target_img

    # For faster debugging, limit the size of data
    if max_dataset_len:
        input_img_paths = input_img_paths[:max_dataset_len]
        target_img_paths = target_img_paths[:max_dataset_len]
    dataset = tf_data.Dataset.from_tensor_slices((input_img_paths, target_img_paths))
    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.batch(batch_size)
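Before building the model, it can help to pull one batch through the pipeline and confirm shapes and dtypes match expectations (a debugging sketch, not part of the original example):

In [0]:
# Peek at one batch from a small preview dataset.
preview_dataset = get_dataset(
    batch_size, img_size, input_img_paths, target_img_paths, max_dataset_len=64
)
for images, masks in preview_dataset.take(1):
    print("images:", images.shape, images.dtype)  # e.g. (32, 160, 160, 3) float32
    print("masks: ", masks.shape, masks.dtype)  # e.g. (32, 160, 160, 1) uint8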
Prepare U-Net Xception-style model
In [0]:
from keras import layers


def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Build model
model = get_model(img_size, num_classes)
model.summary()
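Because the downsampling path halves the spatial resolution four times (the stride-2 entry convolution plus three stride-2 max-pools) and the upsampling path doubles it four times, the output mask has the same 160x160 resolution as the input. A one-line check (not in the original example):

In [0]:
# The model should map 160x160 RGB inputs to a per-pixel softmax over 3 classes.
print(model.output_shape)  # expected: (None, 160, 160, 3)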
Set aside a validation split
In [0]:
import random

# Split our img paths into a training and a validation set.
# Shuffling both lists with the same seed keeps images and masks paired.
val_samples = 1000
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate a dataset for each split.
# The training split is capped via `max_dataset_len` for faster epochs;
# remove the `max_dataset_len` arg to train on the full dataset.
train_dataset = get_dataset(
    batch_size,
    img_size,
    train_input_img_paths,
    train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
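A quick check on the resulting splits (optional, not part of the original example):

In [0]:
# Report split sizes and make sure the two splits don't overlap.
print("Train samples:", len(train_input_img_paths))
print("Validation samples:", len(val_input_img_paths))
assert not set(train_input_img_paths) & set(val_input_img_paths)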
Train the model
In [0]:
# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy"
)

callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.keras", save_best_only=True)
]

# Train the model, doing validation at the end of each epoch.
epochs = 50
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
    verbose=2,
)
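Since `ModelCheckpoint` is configured with `save_best_only=True`, the file `oxford_segmentation.keras` holds the weights from the epoch with the lowest validation loss. If you come back to the notebook later, you can reload that checkpoint instead of retraining (a minimal sketch, assuming the file exists on disk):

In [0]:
# Restore the best checkpoint written during training.
model = keras.models.load_model("oxford_segmentation.keras")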
Visualize predictions
In [0]:
# Generate predictions for all images in the validation set
val_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Quick utility to display a model's prediction."""
    mask = np.argmax(val_preds[i], axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
    display(img)


# Display results for validation image #10
i = 10

# Display input image
display(Image(filename=val_input_img_paths[i]))

# Display ground-truth target mask
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs at 160x160.
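Visual inspection is a good smoke test, but a scalar metric makes it easier to compare runs. A rough sketch (not part of the original example) that estimates mean IoU on the validation set from the predictions computed above:

In [0]:
# Mean IoU between argmax predictions and ground-truth trimaps (labels 0, 1, 2).
miou = keras.metrics.MeanIoU(num_classes=num_classes)
offset = 0
for _, masks in val_dataset:
    n = masks.shape[0]
    preds = np.argmax(val_preds[offset : offset + n], axis=-1)
    miou.update_state(masks[..., 0], preds)
    offset += n
print("Validation mean IoU:", float(miou.result()))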