3D image classification from CT scans
Author: Hasib Zunair
Date created: 2020/09/23
Last modified: 2024/01/11
Description: Train a 3D convolutional neural network to predict presence of pneumonia.
Introduction
This example shows the steps needed to build a 3D convolutional neural network (CNN) to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan). 3D CNNs are a powerful model for learning representations of volumetric data.
References
Setup
Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings
In this example, we use a subset of the MosMedData: Chest CT Scans with COVID-19 Related Findings. This dataset consists of lung CT scans with COVID-19 related findings, as well as without such findings.
We will be using the associated radiological findings of the CT scans as labels to build a classifier to predict the presence of viral pneumonia. Hence, the task is a binary classification problem.
Loading data and preprocessing
The files are provided in NIfTI format with the extension .nii. To read the scans, we use the nibabel package, which you can install via pip install nibabel. CT scans store raw voxel intensity in Hounsfield units (HU), which in this dataset range from -1024 to above 2000. Values above 400 correspond to bones of varying radiointensity, so 400 is used as the upper bound. A threshold range of -1000 to 400 is commonly used to normalize CT scans.
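The thresholding and rescaling described above can be sketched as a small NumPy function (the bounds -1000 and 400 are the ones stated in the text; the function name is illustrative):

```python
import numpy as np

def normalize(volume):
    """Clip Hounsfield units to [-1000, 400] and rescale to [0, 1]."""
    min_hu, max_hu = -1000, 400
    volume = np.clip(volume, min_hu, max_hu)
    volume = (volume - min_hu) / (max_hu - min_hu)
    return volume.astype("float32")
```

Clipping before rescaling means extreme values (air below -1000 HU, bone above 400 HU) saturate at 0 and 1 rather than stretching the useful soft-tissue range.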
To process the data, we do the following:
We first rotate the volumes by 90 degrees, so the orientation is fixed.
We scale the HU values to be between 0 and 1.
We resize width, height and depth.
Here we define several helper functions to process the data. These functions will be used when building training and validation datasets.
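A sketch of such a helper, assuming SciPy for the resampling (in the full example the scans themselves would be loaded with nibabel via nib.load(path).get_fdata(); the function and argument names here are illustrative):

```python
import numpy as np
from scipy import ndimage

def resize_volume(img, desired_width=128, desired_height=128, desired_depth=64):
    """Rotate a 3D volume to fix its orientation, then resize all three axes."""
    width_factor = desired_width / img.shape[0]
    height_factor = desired_height / img.shape[1]
    depth_factor = desired_depth / img.shape[2]
    # Rotate by 90 degrees in the in-plane axes so the orientation is fixed.
    img = ndimage.rotate(img, 90, reshape=False)
    # Resample with linear interpolation (order=1) across all three axes.
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img
```

Linear interpolation (order=1) is a reasonable default for downsampling intensity volumes; higher spline orders are slower and can overshoot near sharp boundaries.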
Let's read the paths of the CT scans from the class directories.
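Collecting the paths can be done with a small helper along these lines (the directory layout and function name are assumptions for illustration):

```python
import os

def scan_paths(data_dir, class_name):
    """List the .nii scan paths under one class directory."""
    class_dir = os.path.join(data_dir, class_name)
    return [
        os.path.join(class_dir, fname)
        for fname in sorted(os.listdir(class_dir))
        if fname.endswith(".nii")
    ]
```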
Build train and validation datasets
Read the scans from the class directories and assign labels. Downsample the scans to a shape of 128x128x64. Rescale the raw HU values to the range 0 to 1. Lastly, split the dataset into train and validation subsets.
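The label assignment and split can be sketched as follows, using placeholder arrays in place of the preprocessed scans (the per-class counts and the 7:3 split are illustrative, not the exact numbers in the full example):

```python
import numpy as np

# Stand-ins for the preprocessed scans; each entry has shape (128, 128, 64).
abnormal_scans = np.zeros((10, 128, 128, 64), dtype="float32")
normal_scans = np.zeros((10, 128, 128, 64), dtype="float32")

# Scans with findings get label 1, scans without findings get label 0.
abnormal_labels = np.ones(len(abnormal_scans), dtype="float32")
normal_labels = np.zeros(len(normal_scans), dtype="float32")

# Splitting each class separately keeps both subsets roughly class-balanced.
x_train = np.concatenate((abnormal_scans[:7], normal_scans[:7]), axis=0)
y_train = np.concatenate((abnormal_labels[:7], normal_labels[:7]), axis=0)
x_val = np.concatenate((abnormal_scans[7:], normal_scans[7:]), axis=0)
y_val = np.concatenate((abnormal_labels[7:], normal_labels[7:]), axis=0)
```

Splitting per class rather than after shuffling the pooled data guarantees the validation set stays balanced, which matters later when accuracy is used as the evaluation metric.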
Data augmentation
The CT scans are also augmented by rotating them at random angles during training. Since the data is stored in rank-3 tensors of shape (samples, height, width, depth), we add a dimension of size 1 at axis 4 to be able to perform 3D convolutions on the data. The new shape is thus (samples, height, width, depth, 1). There are many kinds of preprocessing and augmentation techniques out there; this example shows a few simple ones to get started.
While defining the train and validation data loaders, the training data is passed through an augmentation function which randomly rotates the volumes at different angles. Note that both training and validation data are already rescaled to have values between 0 and 1.
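The rotation augmentation and the added channel dimension can be sketched as below, assuming SciPy for the rotation (the angle set and function names are illustrative):

```python
import random
import numpy as np
from scipy import ndimage

def rotate_volume(volume):
    """Rotate a volume in-plane by a randomly chosen small angle."""
    angle = random.choice([-20, -10, -5, 5, 10, 20])
    volume = ndimage.rotate(volume, angle, reshape=False, order=1)
    # Interpolation can push values slightly outside [0, 1]; clip them back.
    return np.clip(volume, 0.0, 1.0)

def train_preprocessing(volume, label):
    """Augment a training volume and add a channel axis for 3D convolutions."""
    volume = rotate_volume(volume)
    volume = volume[..., np.newaxis]  # (h, w, d) -> (h, w, d, 1)
    return volume, label
```

Only the training loader would apply rotate_volume; the validation loader would add the channel axis without augmenting.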
Visualize an augmented CT scan.
Since a CT scan has many slices, let's visualize a montage of the slices.
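A montage can be drawn with a helper along these lines (a sketch with matplotlib; the function name and grid layout are assumptions):

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_slices(num_rows, num_cols, data):
    """Show a grid of CT slices; data has shape (width, height, num_slices)."""
    data = np.transpose(data, (2, 0, 1))  # -> (num_slices, width, height)
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 1.8, num_rows * 1.8)
    )
    for i, ax in enumerate(axes.flat):
        ax.imshow(data[i], cmap="gray")
        ax.axis("off")
    fig.tight_layout()
    return fig
```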
Define a 3D convolutional neural network
To make the model easier to understand, we structure it into blocks. The architecture of the 3D CNN used in this example is based on this paper.
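A block-structured 3D CNN along the lines described can be sketched in Keras as follows (the exact filter counts and head sizes are an assumption for illustration):

```python
from tensorflow import keras
from tensorflow.keras import layers

def get_model(width=128, height=128, depth=64):
    """A small 3D CNN: four conv blocks, then global pooling and a dense head."""
    inputs = keras.Input((width, height, depth, 1))
    x = inputs
    # Each block: 3D convolution, spatial downsampling, normalization.
    for filters in [64, 64, 128, 256]:
        x = layers.Conv3D(filters=filters, kernel_size=3, activation="relu")(x)
        x = layers.MaxPool3D(pool_size=2)(x)
        x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    # Single sigmoid unit for the binary (pneumonia / no pneumonia) output.
    outputs = layers.Dense(units=1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs, name="3dcnn")
```

Global average pooling in place of flattening keeps the parameter count low, which helps given how few training scans are available.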
Train model
It is important to note that the number of samples is very small (only 200) and we don't specify a random seed. As such, you can expect significant variance in the results. The full dataset which consists of over 1000 CT scans can be found here. Using the full dataset, an accuracy of 83% was achieved. A variability of 6-7% in the classification performance is observed in both cases.
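A training setup in this spirit, compressed to a tiny stand-in model and random data so it runs quickly (the learning-rate schedule values and callback choices are illustrative, not the exact settings of the full example):

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Tiny stand-in for the 3D CNN and the CT volumes, just to show the wiring.
model = keras.Sequential([
    keras.Input((16, 16, 8, 1)),
    layers.Conv3D(8, 3, activation="relu"),
    layers.GlobalAveragePooling3D(),
    layers.Dense(1, activation="sigmoid"),
])

# Exponentially decaying learning rate for binary cross-entropy training.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    1e-4, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
)

# Stop training when validation loss plateaus.
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_loss", patience=15)

x = np.random.rand(8, 16, 16, 8, 1).astype("float32")
y = np.random.randint(0, 2, size=(8,)).astype("float32")
history = model.fit(
    x, y, validation_split=0.25, epochs=2, batch_size=2,
    callbacks=[early_stopping_cb], verbose=0,
)
```

In the full example the random arrays would be replaced by the train and validation datasets built earlier, with more epochs and a model checkpoint callback to keep the best weights.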
Visualizing model performance
Here the model accuracy and loss for the training and the validation sets are plotted. Since the validation set is class-balanced, accuracy provides an unbiased representation of the model's performance.
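The curves can be produced from the fit history with a small helper like this (a sketch; it assumes the metric was registered under the name "acc"):

```python
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot train/validation accuracy and loss from a Keras History object."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, metric in zip(axes, ["acc", "loss"]):
        ax.plot(history.history[metric], label="train")
        ax.plot(history.history["val_" + metric], label="validation")
        ax.set_title("Model " + metric)
        ax.set_xlabel("epoch")
        ax.legend()
    return fig
```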