3D Multimodal Brain Tumor Segmentation
Author: Mohammed Innat
Date created: 2026/02/02
Last modified: 2026/02/02
Description: Implementing a 3D semantic segmentation pipeline for medical imaging.
Brain tumor segmentation is a core task in medical image analysis, where the goal is to automatically identify and label different tumor sub-regions from 3D MRI scans. Accurate segmentation helps clinicians with diagnosis, treatment planning, and disease monitoring. In this tutorial, we focus on multimodal MRI-based brain tumor segmentation using the widely adopted BraTS (Brain Tumor Segmentation) dataset.
The BraTS Dataset
The BraTS dataset provides multimodal 3D brain MRI scans, released as NIfTI files (.nii.gz). For each patient, four MRI modalities are available:
T1 – native T1-weighted MRI
T1Gd – post-contrast T1-weighted MRI
T2 – T2-weighted MRI
T2-FLAIR – Fluid Attenuated Inversion Recovery MRI
These scans are collected using different scanners and clinical protocols from 19 institutions, making the dataset diverse and realistic. More details about the dataset can be found in the official BraTS documentation.
Segmentation Labels
Each scan is manually annotated by one to four expert raters, following a standardized annotation protocol and reviewed by experienced neuroradiologists. The segmentation masks contain the following tumor sub-regions:
NCR / NET (label 1) – Necrotic and non-enhancing tumor core
ED (label 2) – Peritumoral edema
ET (label 4) – GD-enhancing tumor
0 – Background (non-tumor tissue)
The data are released after preprocessing:
All modalities are co-registered
Resampled to 1 mm³ isotropic resolution
Skull-stripped for consistency
Dataset Format and TFRecord Conversion
The original BraTS scans are provided in .nii format and can be accessed from Kaggle here. To enable efficient training pipelines, we convert the NIfTI files into TFRecord format:
The conversion process is documented here
The preprocessed TFRecord dataset is available here
Each TFRecord file contains up to 20 cases
Since BraTS does not provide publicly available ground-truth labels for validation or test sets, we will hold out a subset of TFRecord files from training for validation purposes.
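For example, the split can be done at the file level. Below is a minimal sketch; the local directory path is hypothetical, while the shard naming follows the training_shard_*.tfrec pattern used later in this tutorial.

```python
import glob

# Hypothetical local path to the downloaded TFRecord shards; adjust as needed.
tfrecord_paths = sorted(glob.glob("brats_tfrecords/training_shard_*.tfrec"))

# Hold out the last shard for validation and use the rest for training.
val_paths = tfrecord_paths[-1:]
train_paths = tfrecord_paths[:-1]
print(len(train_paths), "training shards,", len(val_paths), "validation shard")
```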
What This Tutorial Covers
In this tutorial, we provide a step-by-step, end-to-end workflow for brain tumor segmentation using medicai, a Keras-based medical imaging library with multi-backend support. We will walk through:
Loading the Dataset
Read TFRecord files that contain image, label, and affine matrix information.
Build efficient data pipelines using the tf.data API for training and evaluation.
Medical Image Preprocessing
Apply image transformations provided by medicai to prepare the data for model input.
Model Building
Creating a 3D SwinUNETR segmentation model with medicai
Loss and Metrics Definition
Using Dice-based loss functions and segmentation metrics tailored for medical imaging
Model Evaluation
Performing inference on large 3D volumes using sliding window inference
Computing per-class evaluation metrics
Visualization of Results
Visualizing predicted segmentation masks for qualitative analysis
By the end of this tutorial, you will have a complete brain tumor segmentation pipeline, from data loading and preprocessing to model training, evaluation, and visualization using modern 3D deep learning techniques and the medicai framework.
Download the dataset from Kaggle.
Imports
Create Multi-label Brain Tumor Labels
The BraTS segmentation task involves multiple tumor sub-regions, and it is formulated as a multi-label segmentation problem. The label combinations are used to define the following clinical regions of interest:
Tumor Core (TC) – labels 1 and 4 (NCR/NET and GD-enhancing tumor)
Whole Tumor (WT) – labels 1, 2, and 4 (all tumor sub-regions)
Enhancing Tumor (ET) – label 4 (GD-enhancing tumor)
These region-wise groupings allow for evaluation across different tumor structures relevant for clinical assessment and treatment planning. A sample view is shown below; the figure is taken from the BraTS benchmark paper.

Managing Multi-Label Outputs with TensorBundle
To organize and manage these multi-label segmentation targets, we will implement a custom transformation using TensorBundle from medicai. The TensorBundle is a lightweight container class designed to hold:
A dictionary of tensors (e.g., images, labels)
Optional metadata associated with those tensors (e.g., affine matrices, spacing, original shapes)
This design allows data and metadata to be passed together through the transformation pipeline in a structured and consistent way. Each medicai transformation expects inputs to be organized as key:value pairs, for example:
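A minimal illustration of this key:value structure is shown below; the tensor shapes are placeholders, and the affine metadata key mirrors the information stored in the TFRecords.

```python
import tensorflow as tf

# A single sample expressed as key:value pairs (placeholder shapes).
data = {
    "image": tf.zeros((96, 96, 96, 4)),  # (depth, height, width, modalities)
    "label": tf.zeros((96, 96, 96, 3)),  # (depth, height, width, tumor regions)
}

# Optional metadata kept alongside the tensors, e.g. the affine matrix
# from the NIfTI header.
meta = {"affine": tf.eye(4)}
```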
Using TensorBundle makes it easier to apply complex medical imaging transformations while preserving spatial and anatomical information throughout preprocessing and model inference.
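A minimal sketch of the custom multi-label transform mentioned above is given below. The TC/WT/ET grouping follows the label combinations described earlier; the dict-style access to the bundle's tensors is an assumption about the TensorBundle interface, so the actual implementation may differ slightly.

```python
import tensorflow as tf

class ConvertToMultiChannel:
    """Convert BraTS label values (1, 2, 4) into three binary channels: TC, WT, ET."""

    def __call__(self, bundle):
        # Assumes dict-style access to the label tensor of shape (depth, height, width).
        label = bundle["label"]
        tc = tf.logical_or(tf.equal(label, 1), tf.equal(label, 4))  # tumor core
        wt = tf.logical_or(tc, tf.equal(label, 2))                  # whole tumor
        et = tf.equal(label, 4)                                     # enhancing tumor
        bundle["label"] = tf.cast(tf.stack([tc, wt, et], axis=-1), tf.float32)
        return bundle
```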
Transformation
Each medicai transformation expects the input to have the shape (depth, height, width, channel). The original .nii files (and the converted .tfrecord files) store volumes with shape (height, width, depth). To make them compatible with medicai, we need to rearrange the shape axes.
Each transformation class of medicai expects input as either a dictionary or a TensorBundle object, as discussed earlier. When a dictionary of input data (along with metadata) is passed, it is automatically wrapped into a TensorBundle instance. The examples below demonstrate how transformations are used in this way.
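The snippet below sketches both steps, reusing the ConvertToMultiChannel sketch defined above. The dummy shapes follow the original 240 × 240 × 155 BraTS volumes with the four modalities stacked along the channel axis; the real tensors would come from the TFRecord parser.

```python
import tensorflow as tf

# Dummy tensors standing in for one decoded sample, stored as (height, width, depth, ...).
image = tf.zeros((240, 240, 155, 4), dtype=tf.float32)
label = tf.zeros((240, 240, 155), dtype=tf.int32)

# Rearrange to the (depth, height, width, channel) layout expected by medicai.
image = tf.transpose(image, perm=[2, 0, 1, 3])
label = tf.transpose(label, perm=[2, 0, 1])

# Apply the multi-channel label conversion on a plain dictionary of tensors;
# medicai transforms accept the same structure and wrap it into a TensorBundle.
data = ConvertToMultiChannel()({"image": image, "label": label})
print(data["image"].shape, data["label"].shape)  # (155, 240, 240, 4) (155, 240, 240, 3)
```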
The tfrecord parser
Dataloader
The training batch size can be set to more than 1 depending on the environment and available resources. However, we intentionally keep the validation batch size as 1 to handle variable-sized samples more flexibly. While padded or ragged batches are alternative options, a batch size of 1 ensures simplicity and consistency during evaluation, especially for 3D medical data.
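A sketch of such a pipeline is shown below; `parse_tfrecord` stands in for the parser from the previous section (not reproduced here), and the batch sizes follow the discussion above.

```python
import tensorflow as tf

def build_dataset(file_paths, batch_size, training=True):
    # `parse_tfrecord` is a placeholder for the parser from the previous section.
    ds = tf.data.TFRecordDataset(file_paths)
    ds = ds.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        ds = ds.shuffle(32)
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)

train_ds = build_dataset(train_paths, batch_size=2, training=True)
val_ds = build_dataset(val_paths, batch_size=1, training=False)  # batch size 1 for variable-sized volumes
```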
Sanity check: Fetch a single validation sample to inspect its shape and values.
Sanity check: Visualize the middle slice of the image and its corresponding label.
Sanity check: Visualize sample image and label channels at the middle slice index.
Model
We will be using the 3D model architecture Swin UNEt TRansformers, i.e., SwinUNETR. It was used by NVIDIA in the BraTS 2021 segmentation challenge, where it was among the top-performing methods. It uses a Swin Transformer encoder to extract features at five different resolutions, and a CNN-based decoder is connected to each resolution through skip connections.
The BraTS dataset provides four input modalities (flair, t1, t1ce, and t2) and three multi-label outputs (tumor core, whole tumor, and enhancing tumor). Accordingly, we will initialize the model with 4 input channels and 3 output channels.
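A sketch of the model creation is shown below; the import path and constructor arguments are assumptions about the medicai API, so check the medicai documentation for the exact signature.

```python
# The import path and constructor arguments below are assumptions about the
# medicai API and may need adjusting.
from medicai.models import SwinUNETR

model = SwinUNETR(
    input_shape=(96, 96, 96, 4),  # 4 input modalities: flair, t1, t1ce, t2
    num_classes=3,                # 3 output channels: TC, WT, ET
)
```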

Callback
We will be using the sliding window inference callback from medicai to perform validation at a fixed epoch interval during training. The interval should be chosen relative to the total number of epochs. For example, if we train for 15 epochs and want to evaluate the model on the validation set every 5 epochs, we set the interval to 5.
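A hypothetical configuration is sketched below; the callback class name and its arguments are assumptions about the medicai API, shown only to illustrate how the interval relates to the total number of epochs.

```python
# Hypothetical callback configuration; the class name and arguments are
# assumptions about the medicai API.
val_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds,
    interval=5,             # evaluate every 5 epochs, e.g. when training for 15 epochs
    roi_size=(96, 96, 96),  # patch size used for sliding-window inference
    overlap=0.5,
    save_path="best_model.weights.h5",
)
```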
Training
Train for more epochs for better optimization.
Let’s take a quick look at how our model performed during training. We will first print the available metrics recorded in the training history, save them to a CSV file for future reference, and then visualize them to better understand the model’s learning progress over epochs.
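The sketch below assumes `history` is the object returned by model.fit() in the training step.

```python
import pandas as pd
import matplotlib.pyplot as plt

hist_df = pd.DataFrame(history.history)
print(hist_df.columns.tolist())           # metrics recorded during training
hist_df.to_csv("training_history.csv", index=False)

# Plot every recorded metric against the epoch index.
hist_df.plot(figsize=(8, 4))
plt.xlabel("epoch")
plt.ylabel("value")
plt.title("Training history")
plt.show()
```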
Evaluation
In this Kaggle notebook (version 5), we trained the model on the entire dataset for approximately 30 epochs. The resulting weights will be used for further evaluation. Note that the validation set used both here and in the Kaggle notebook is the same: training_shard_36.tfrec, which contains 8 samples.
In this section, we perform sliding window inference on the validation dataset and compute Dice scores for overall segmentation quality as well as specific tumor subregions:
Tumor Core (TC)
Whole Tumor (WT)
Enhancing Tumor (ET)
Because the validation volumes are large and vary in size, we iterate over the validation dataloader one sample at a time. The sliding window inference splits each input volume into patches and computes the predictions for each batch.
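The loop below sketches this procedure, assuming the validation dataset yields (image, label) pairs. The Dice computation is a simple per-channel reimplementation for illustration, and `swi` stands for the sliding-window inference object built from the trained model; its exact call signature in medicai is an assumption.

```python
import numpy as np
import tensorflow as tf

def dice_score(y_true, y_pred, eps=1e-6):
    # Per-channel Dice for binary masks of shape (depth, height, width, channel).
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    axes = (0, 1, 2)
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    denom = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    return (2.0 * intersection + eps) / (denom + eps)

scores = []
for image, label in val_ds:  # validation batches of size 1
    logits = swi(image)      # sliding-window inference over the full volume
    pred = tf.cast(tf.sigmoid(logits) > 0.5, tf.float32)
    scores.append(dice_score(label[0], pred[0]).numpy())

mean_dice = np.mean(scores, axis=0)
print(f"Dice TC: {mean_dice[0]:.4f}, WT: {mean_dice[1]:.4f}, ET: {mean_dice[2]:.4f}")
```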
Analyse and Visualize
Let's analyse the model predictions and visualize them. First, we will implement the test transformation pipeline. This is the same as the validation transformation.
Let's load the tfrecord file and check its properties.
Run the transformation to prepare the inputs.
Pass the preprocessed sample to the inference object, ensuring that a batch axis is added to the input beforehand.
After running inference, we remove the batch dimension and apply a sigmoid activation to obtain class probabilities. We then threshold the probabilities at 0.5 to generate the final binary segmentation map.
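In code, these two steps might look like the sketch below, assuming `swi` is the sliding-window inference object and `sample` holds the preprocessed test input from the previous step.

```python
import tensorflow as tf

# Add a batch axis before inference, then drop it again afterwards.
logits = swi(sample["image"][None, ...])   # (1, depth, height, width, 3)
probs = tf.sigmoid(logits[0])              # class probabilities per voxel
segment = tf.cast(probs > 0.5, tf.uint8)   # binary map: channels TC, WT, ET
```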
We compare the ground truth (pre_label) and the predicted segmentation (segment) for each tumor sub-region. Each sub-plot shows a specific channel corresponding to a tumor type: TC, WT, and ET. Here we visualize the 80th axial slice across the three channels.
The predicted output is a multi-channel binary map, where each channel corresponds to a specific tumor region. To visualize it against the original ground truth, we convert it into a single-channel label map. Here we assign:
Label 1 for Tumor Core (TC)
Label 2 for Whole Tumor (WT)
Label 4 for Enhancing Tumor (ET)
The label values are chosen to match typical conventions used in medical segmentation benchmarks like BraTS.
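One way to do this is to paint the regions in order of containment (WT first, then TC, then ET), since ET ⊂ TC ⊂ WT. The snippet below assumes `segment` is the binary prediction from the previous step.

```python
import numpy as np

seg = np.asarray(segment).astype(bool)       # (depth, height, width, 3): TC, WT, ET
label_map = np.zeros(seg.shape[:-1], dtype=np.uint8)
label_map[seg[..., 1]] = 2   # whole tumor
label_map[seg[..., 0]] = 1   # tumor core overrides WT voxels
label_map[seg[..., 2]] = 4   # enhancing tumor has the highest priority
```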
Let's begin by examining the original input slices from the MRI scan. The input contains four channels corresponding to different MRI modalities:
FLAIR
T1
T1CE (T1 with contrast enhancement)
T2
We display the same slice number across all modalities for comparison.
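The snippet below sketches this, assuming `image` is the preprocessed input volume with shape (depth, height, width, 4) and that the channel order matches FLAIR, T1, T1CE, T2.

```python
import numpy as np
import matplotlib.pyplot as plt

modalities = ["FLAIR", "T1", "T1CE", "T2"]
slice_idx = 80  # same axial slice for every modality

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for i, (ax, name) in enumerate(zip(axes, modalities)):
    ax.imshow(np.asarray(image)[slice_idx, :, :, i], cmap="gray")
    ax.set_title(name)
    ax.axis("off")
plt.show()
```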
Next, we compare this input with the ground truth label and the predicted segmentation on the same slice. This provides visual insight into how well the model has localized and segmented the tumor regions.
Finally, create a clean GIF visualizer showing the input image, ground-truth label, and model prediction.
Prepare a visualization-friendly prediction map by remapping label values to a compact index range.
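The sketch below remaps the sparse label values {0, 1, 2, 4} to the compact range {0, 1, 2, 3} and writes an animated GIF. It assumes `label_map` is the single-channel prediction built earlier, `gt_map` is the ground truth converted to a label map in the same way, and `image` is the preprocessed input volume.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# Remap {0, 1, 2, 4} -> {0, 1, 2, 3} with a small lookup table so a fixed
# colormap range renders every region consistently.
lut = np.zeros(5, dtype=np.uint8)
lut[[0, 1, 2, 4]] = [0, 1, 2, 3]
vis_pred = lut[label_map]
vis_true = lut[gt_map]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

def draw(i):
    for ax in axes:
        ax.clear()
        ax.axis("off")
    axes[0].imshow(np.asarray(image)[i, :, :, 0], cmap="gray")
    axes[0].set_title("Input (FLAIR)")
    axes[1].imshow(vis_true[i], cmap="viridis", vmin=0, vmax=3)
    axes[1].set_title("Ground truth")
    axes[2].imshow(vis_pred[i], cmap="viridis", vmin=0, vmax=3)
    axes[2].set_title("Prediction")

anim = animation.FuncAnimation(fig, draw, frames=vis_pred.shape[0], interval=100)
anim.save("brats_prediction.gif", writer="pillow")
```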
When you open the saved GIF, you should see a visualization similar to this.
