Highly accurate boundaries segmentation using BASNet
Author: Hamid Ali
Date created: 2023/05/30
Last modified: 2025/01/24
Description: Boundaries aware segmentation model trained on the DUTS dataset.
Introduction
Deep semantic segmentation algorithms have improved a lot recently, but they still fail to correctly predict pixels around object boundaries. In this example we implement the Boundary-Aware Segmentation Network (BASNet). Using a two-stage predict-and-refine architecture and a hybrid loss, it can predict highly accurate boundaries and fine structures for image segmentation.
Download the Data
We will use the DUTS-TE dataset for training. It has 5,019 images, but we will use only 140 for training and validation to keep the notebook's running time short. DUTS is a relatively large salient object segmentation dataset that contains diversified textures and structures common to real-world images in both foreground and background.
Define Hyperparameters
Create PyDatasets
We will use load_paths() to load and split 140 paths into train and validation sets, and convert the paths into PyDataset objects.
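The path-splitting step can be sketched roughly as follows. This is a minimal illustration, not the notebook's actual helper: the signature of load_paths(), the 0.9 split ratio, the file extensions, and the assumption that images and masks pair up by sorted file name are all hypothetical.

```python
import os
from glob import glob


def load_paths(image_dir, mask_dir, num_samples=140, split_ratio=0.9):
    """Return (train_pairs, val_pairs), each a list of (image_path, mask_path).

    Hypothetical sketch: assumes .jpg images and .png masks whose sorted
    file names line up one-to-one.
    """
    images = sorted(glob(os.path.join(image_dir, "*.jpg")))[:num_samples]
    masks = sorted(glob(os.path.join(mask_dir, "*.png")))[:num_samples]
    if len(images) != len(masks):
        raise ValueError("Number of images and masks differs.")

    pairs = list(zip(images, masks))
    split_index = int(len(pairs) * split_ratio)  # first part trains, rest validates
    return pairs[:split_index], pairs[split_index:]
```

Each returned list could then be handed to a dataset class that loads and resizes the files batch by batch.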
Visualize Data
Analyze Mask
Let's print the unique values of the mask displayed above. You can see that, despite belonging to one class, its intensity varies between low (0) and high (255). This variation in intensity makes it hard for a network to generate a good segmentation map for salient or camouflaged object segmentation. Because of its Residual Refinement Module (RM), BASNet is good at generating highly accurate boundaries and fine structures.
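As a small illustration of this intensity variation, the sketch below inspects a synthetic mask row rather than a real DUTS mask. The array values are made up, and the scaling and 0.5 threshold are only one common preprocessing choice, not necessarily what this notebook does.

```python
import numpy as np

# Synthetic single-row "mask" standing in for a real DUTS mask: a soft edge
# produces grey values between background (0) and foreground (255).
mask = np.array([[0, 0, 64, 191, 255, 255]], dtype=np.uint8)

# More than two unique values appear even though there is only one class.
unique_values = np.unique(mask)
print(unique_values)

# Scale to [0, 1] so the mask can be compared with sigmoid outputs.
soft_mask = mask.astype("float32") / 255.0

# Optional hard mask (assumed threshold of 0.5) for a strict binary target.
binary_mask = (soft_mask > 0.5).astype("float32")
print(binary_mask)
```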
Building the BASNet Model
BASNet comprises a predict-refine architecture and a hybrid loss. The predict-refine architecture consists of a densely supervised encoder-decoder network and a residual refinement module, which are respectively used to predict and refine a segmentation probability map.
Prediction Module
The prediction module is a heavy encoder-decoder structure like U-Net. The encoder includes an input convolutional layer and six stages. The first four stages are adopted from ResNet-34 and the rest are basic res-blocks. Since the first convolution and pooling layers of ResNet-34 are skipped, we will use get_resnet_block() to extract the first four blocks. Both the bridge and the decoder use three convolutional layers with side outputs. The module produces seven segmentation probability maps during training, with the last one considered the final output.
Residual Refinement Module
The Refinement Module (RM), designed as a residual block, aims to refine the coarse segmentation maps (with blurry and noisy boundaries) generated by the prediction module. Like the prediction module, it is an encoder-decoder structure, but a lightweight one with four stages, each containing one convolutional block. At the end, it adds the coarse and residual outputs to generate the refined output.
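The final addition step can be sketched in NumPy. Only the residual sum (refined = coarse + residual) reflects the description above. The function toy_residual_branch() is a hand-written, hypothetical stand-in for the learned encoder-decoder and is not part of BASNet.

```python
import numpy as np


def refine(coarse, residual_branch):
    """Add the predicted residual to the coarse map, as a residual block does."""
    return coarse + residual_branch(coarse)


def toy_residual_branch(coarse):
    # Hypothetical stand-in for the learned encoder-decoder: it nudges each
    # value halfway toward the nearest of 0 or 1.
    return np.where(coarse > 0.5, 1.0 - coarse, -coarse) * 0.5


coarse_map = np.array([[0.1, 0.4, 0.6, 0.9]])
refined_map = refine(coarse_map, toy_residual_branch)
print(refined_map)
```

In the real model the residual comes from trained convolutional layers, so the correction is learned from data rather than fixed by a rule like the one above.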
Combine Predict and Refinement Module
Hybrid Loss
Another important feature of BASNet is its hybrid loss function, a combination of binary cross-entropy, structural similarity, and intersection-over-union losses. Together they guide the network to learn a three-level hierarchy of representations: pixel, patch, and map level.
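A simplified NumPy sketch of the three terms is given below. It assumes the terms are summed with equal weight, and it computes structural similarity over the whole map as a single patch instead of over sliding windows, so it is an approximation for illustration rather than the notebook's loss. The function names and the constants c1 and c2 are assumptions.

```python
import numpy as np


def bce_loss(y_true, y_pred, eps=1e-7):
    """Pixel level: mean binary cross-entropy."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(y_pred)
                          + (1.0 - y_true) * np.log(1.0 - y_pred)))


def ssim_loss(y_true, y_pred, c1=0.01 ** 2, c2=0.03 ** 2):
    """Patch level: 1 - SSIM, treating the whole map as one patch (simplification)."""
    mu_x, mu_y = y_pred.mean(), y_true.mean()
    var_x, var_y = y_pred.var(), y_true.var()
    cov_xy = ((y_pred - mu_x) * (y_true - mu_y)).mean()
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2))
    return float(1.0 - ssim)


def iou_loss(y_true, y_pred, eps=1e-7):
    """Map level: 1 - soft intersection over union."""
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true + y_pred - y_true * y_pred)
    return float(1.0 - intersection / (union + eps))


def hybrid_loss(y_true, y_pred):
    """Equal-weight sum of the three terms (assumed weighting)."""
    return (bce_loss(y_true, y_pred)
            + ssim_loss(y_true, y_pred)
            + iou_loss(y_true, y_pred))


target = np.array([[0.0, 1.0], [1.0, 1.0]])
print(hybrid_loss(target, target))        # close to zero for a perfect prediction
print(hybrid_loss(target, 1.0 - target))  # large for an inverted prediction
```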
Train the Model
Visualize Predictions
In the paper, BASNet was trained on the DUTS-TR dataset, which has 10,553 images. The model was trained for 400k iterations with a batch size of eight and without a validation dataset. After training, the model was evaluated on the DUTS-TE dataset and achieved a mean absolute error of 0.042.
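For reference, mean absolute error between a predicted probability map and its ground-truth mask can be computed as in this minimal sketch. The function name and the toy arrays are illustrative assumptions. The figure reported above is this quantity averaged over the evaluation images.

```python
import numpy as np


def mean_absolute_error(y_true, y_pred):
    """Average per-pixel absolute difference between two maps scaled to [0, 1]."""
    y_true = np.asarray(y_true, dtype="float64")
    y_pred = np.asarray(y_pred, dtype="float64")
    return float(np.mean(np.abs(y_true - y_pred)))


# Toy 2x2 example with made-up values.
ground_truth = np.array([[0.0, 1.0], [1.0, 0.0]])
prediction = np.array([[0.1, 0.9], [0.8, 0.0]])
mae = mean_absolute_error(ground_truth, prediction)
print(mae)
```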
Since BASNet is a deep model and cannot be trained in the short amount of time required of a Keras example notebook, we will load pretrained weights from here to show model predictions. Due to limited computing power, this model was trained for only 120k iterations, but it still demonstrates its capabilities. For further details about the training parameters, please check the given link.