"""1Title: 3D Multimodal Brain Tumor Segmentation2Author: [Mohammed Innat](https://www.linkedin.com/in/innat2k14/)3Date created: 2026/02/024Last modified: 2026/02/025Description: Implementing 3D semantic segmentation pipeline for medical imaging.6Accelerator: GPU7"""89"""10Brain tumor segmentation is a core task in medical image analysis, where the goal is to automatically identify and label different tumor sub-regions from 3D MRI scans. Accurate segmentation helps clinicians with diagnosis, treatment planning, and disease monitoring. In this tutorial, we focus on multimodal MRI-based brain tumor segmentation using the widely adopted **BraTS** (**Brain Tumor Segmentation**) dataset.1112## The BraTS Dataset1314The **BraTS** dataset provides multimodal 3D brain MRI scans, released as NIfTI files (`.nii.gz`). For each patient, four MRI modalities are available:1516- **T1** – native T1-weighted MRI17- **T1Gd** – post-contrast T1-weighted MRI18- **T2** – T2-weighted MRI19- **T2-FLAIR** – Fluid Attenuated Inversion Recovery MRI2021These scans are collected using different scanners and clinical protocols from 19 institutions, making the dataset diverse and realistic. More details about the dataset can be found in the official [BraTS documentation](https://www.med.upenn.edu/cbica/brats2020/data.html).22"""2324"""25## Segmentation Labels2627Each scan is manually annotated by **one to four expert raters**, following a standardized annotation protocol and reviewed by experienced neuroradiologists. The segmentation masks contain the following tumor sub-regions:2829- **NCR / NET (label 1)** – Necrotic and non-enhancing tumor core30- **ED (label 2)** – Peritumoral edema31- **ET (label 4)** – GD-enhancing tumor32- **0** – Background (non-tumor tissue)3334The data are released after preprocessing:3536- All modalities are **co-registered**37- Resampled to `1 mm³` isotropic resolution38- **Skull-stripped** for consistency39"""4041"""42## Dataset Format and TFRecord Conversion4344The original BraTS scans are provided in `.nii` format and can be accessed from Kaggle [here](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation/). To enable **efficient training pipelines**, we convert the NIfTI files into **TFRecord** format:4546- The conversion process is documented [here](https://www.kaggle.com/code/ipythonx/brats-nii-to-tfrecord)47- The preprocessed TFRecord dataset is available [here](https://www.kaggle.com/datasets/ipythonx/brats2020)48- Each TFRecord file contains **up to 20 cases**4950Since BraTS does not provide publicly available ground-truth labels for validation or test sets, we will **hold out a subset of TFRecord files** from training for validation purposes.515253# What This Tutorial Covers5455In this tutorial, we provide a step-by-step, end-to-end workflow for brain tumor segmentation using [medicai](https://github.com/innat/medic-ai), a Keras-based medical imaging library with multi-backend support. We will walk through:56571. **Loading the Dataset**58- Read TFRecord files that contain `image`, `label`, and `affine` matrix information.59- Build efficient data pipelines using the `tf.data` API for training and evaluation.602. **Medical Image Preprocessing**61- Apply image transformations provided by `medicai` to prepare the data for model input.623. 
"""
## What This Tutorial Covers

In this tutorial, we provide a step-by-step, end-to-end workflow for brain tumor segmentation using [medicai](https://github.com/innat/medic-ai), a Keras-based medical imaging library with multi-backend support. We will walk through:

1. **Loading the Dataset**
- Read TFRecord files that contain `image`, `label`, and `affine` matrix information.
- Build efficient data pipelines using the `tf.data` API for training and evaluation.
2. **Medical Image Preprocessing**
- Apply image transformations provided by `medicai` to prepare the data for model input.
3. **Model Building**
- Construct a 3D segmentation model with [`SwinUNETR`](https://arxiv.org/abs/2201.01266). You can also experiment with other available 3D architectures, including [`UNETR`](https://arxiv.org/abs/2103.10504), [`SegFormer`](https://arxiv.org/abs/2404.10156), and [`UNETR++`](https://ieeexplore.ieee.org/document/10526382).
4. **Loss and Metrics Definition**
- Use Dice-based loss functions and segmentation metrics tailored for medical imaging.
5. **Model Evaluation**
- Perform inference on large 3D volumes using **sliding window inference**.
- Compute per-class evaluation metrics.
6. **Visualization of Results**
- Visualize predicted segmentation masks for qualitative analysis.

By the end of this tutorial, you will have a complete brain tumor segmentation pipeline, from data loading and preprocessing to model training, evaluation, and visualization, using modern 3D deep learning techniques and the `medicai` framework.
"""

"""
## Installation

We will install the following packages: [`kagglehub`](https://github.com/Kaggle/kagglehub) for downloading the dataset from Kaggle, and [`medicai`](https://github.com/innat/medic-ai) for accessing specialized methods for medical imaging, including 3D transformations, model architectures, loss functions, metrics, and other essential components.

```shell
!pip install kagglehub -qU
!pip install git+https://github.com/innat/medic-ai.git -qU
```
"""
import os
import warnings

warnings.filterwarnings("ignore")

import shutil
import kagglehub
from IPython.display import clear_output

if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()

"""
Download the dataset from Kaggle.
"""
dataset_id = "ipythonx/brats2020"
destination_path = "brats2020_subset"
os.makedirs(destination_path, exist_ok=True)

# Download three shards: shards 0 and 1 for the training set, shard 36 for the validation set.
for i in [0, 1, 36]:
    filename = f"training_shard_{i}.tfrec"
    print(f"Downloading {filename}...")
    path = kagglehub.dataset_download(dataset_id, path=filename)
    shutil.move(path, destination_path)

# Comment this line to keep the logs visible
clear_output()

"""
## Imports
"""
os.environ["KERAS_BACKEND"] = "jax"  # tensorflow, torch, jax

import numpy as np
import pandas as pd

import keras
from keras import ops
import tensorflow as tf

from matplotlib import pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import ListedColormap

from medicai.callbacks import SlidingWindowInferenceCallback
from medicai.losses import BinaryDiceCELoss
from medicai.metrics import BinaryDiceMetric
from medicai.models import SwinUNETR
from medicai.transforms import (
    Compose,
    CropForeground,
    NormalizeIntensity,
    RandFlip,
    RandShiftIntensity,
    RandSpatialCrop,
    TensorBundle,
)
from medicai.utils.inference import SlidingWindowInference

# enable mixed precision
keras.mixed_precision.set_global_policy("mixed_float16")

# reproducibility
keras.utils.set_random_seed(101)

print(
    f"keras backend: {keras.config.backend()}\n"
    f"keras version: {keras.version()}\n"
    f"tensorflow version: {tf.__version__}\n"
)
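"""
As a small optional check (a minimal sketch), we can confirm that the global mixed-precision policy will make layers compute in `float16` while keeping `float32` variables:

```python
policy = keras.mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)  # float16 float32
```
"""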
"""
## Create Multi-label Brain Tumor Labels

The BraTS segmentation task involves multiple tumor sub-regions, and it is formulated as a multi-label segmentation problem. The label combinations are used to define the following clinical regions of interest:

```
- Tumor Core (TC): label = 1 or 4
- Whole Tumor (WT): label = 1 or 2 or 4
- Enhancing Tumor (ET): label = 4
```

These region-wise groupings allow for evaluation across different tumor structures relevant for clinical assessment and treatment planning. A sample view of these regions can be found in the [BraTS benchmark paper](https://arxiv.org/abs/2107.02314).

## Managing Multi-Label Outputs with `TensorBundle`

To organize and manage these multi-label segmentation targets, we will implement a custom transformation using [**TensorBundle**](https://github.com/innat/medic-ai/blob/2d2139020531acd1c2d41b07a10daf04ceb150f4/medicai/transforms/tensor_bundle.py#L9) from `medicai`. The `TensorBundle` is a lightweight container class designed to hold:

- A dictionary of tensors (e.g., images, labels)
- Optional metadata associated with those tensors (e.g., affine matrices, spacing, original shapes)

This design allows data and metadata to be passed together through the transformation pipeline in a structured and consistent way. Each `medicai` transformation expects inputs to be organized as `key:value` pairs, for example:

```python
meta = {"affine": affine}
data = {"image": image, "label": label}
```

Using `TensorBundle` makes it easier to apply complex medical imaging transformations while preserving spatial and anatomical information throughout preprocessing and model inference.
"""


class ConvertToMultiChannelBasedOnBratsClasses:
    """
    Convert labels to multi-channel masks based on BraTS classes using TensorFlow.

    Label definitions:
    - 1: necrotic and non-enhancing tumor core
    - 2: peritumoral edema
    - 4: GD-enhancing tumor

    Output channels:
    - Channel 0 (TC): Tumor core (labels 1 or 4)
    - Channel 1 (WT): Whole tumor (labels 1, 2, or 4)
    - Channel 2 (ET): Enhancing tumor (label 4)
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, inputs):
        if isinstance(inputs, dict):
            inputs = TensorBundle(inputs)

        for key in self.keys:
            data = inputs[key]

            # TC: label == 1 or 4
            tc = tf.logical_or(tf.equal(data, 1), tf.equal(data, 4))

            # WT: label == 1 or 2 or 4
            wt = tf.logical_or(tc, tf.equal(data, 2))

            # ET: label == 4
            et = tf.equal(data, 4)

            stacked = tf.stack(
                [
                    tf.cast(tc, tf.float32),
                    tf.cast(wt, tf.float32),
                    tf.cast(et, tf.float32),
                ],
                axis=-1,
            )

            inputs[key] = stacked
        return inputs
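"""
To make the label grouping concrete, here is a small optional check (a minimal sketch on a synthetic label volume) showing how the raw BraTS labels map onto the three output channels:

```python
# A tiny synthetic label volume containing the raw labels 0, 1, 2, and 4.
toy_label = tf.constant([[[0.0, 1.0], [2.0, 4.0]]])
converter = ConvertToMultiChannelBasedOnBratsClasses(keys=["label"])
converted = converter({"label": toy_label})["label"]
print(converted.shape)  # (1, 2, 2, 3): channels are (TC, WT, ET)
print(converted.numpy()[..., 0])  # TC is active where the raw label is 1 or 4
```
"""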
"""
## Transformation

Each `medicai` transformation expects the input to have the shape `(depth, height, width, channel)`. The original `.nii` (and converted `.tfrecord`) format stores volumes as `(height, width, depth)`. To make them compatible with `medicai`, we need to rearrange the shape axes.
"""


def rearrange_shape(sample):
    # unpack sample
    image = sample["image"]
    label = sample["label"]
    affine = sample["affine"]

    # move the depth axis to the front
    image = tf.transpose(image, perm=[2, 1, 0, 3])  # whdc -> dhwc
    label = tf.transpose(label, perm=[2, 1, 0])  # whd -> dhw
    cols = tf.gather(affine, [2, 1, 0], axis=1)  # (whd) -> (dhw)
    affine = tf.concat([cols, affine[:, 3:]], axis=1)

    # update sample with the new / updated tensors
    sample["image"] = image
    sample["label"] = label
    sample["affine"] = affine
    return sample


"""
Each transformation class of `medicai` expects input as either a dictionary or a `TensorBundle` object, as discussed earlier. When a dictionary of input data (along with metadata) is passed, it is automatically wrapped into a `TensorBundle` instance. The examples below demonstrate how transformations are used in this way.
"""

num_classes = 3
epochs = 4
input_shape = (96, 96, 96, 4)


def train_transformation(sample):
    meta = {"affine": sample["affine"]}
    data = {"image": sample["image"], "label": sample["label"]}

    pipeline = Compose(
        [
            ConvertToMultiChannelBasedOnBratsClasses(keys=["label"]),
            CropForeground(
                keys=("image", "label"),
                source_key="image",
                k_divisible=input_shape[:3],
            ),
            RandSpatialCrop(
                keys=["image", "label"], roi_size=input_shape[:3], random_size=False
            ),
            RandFlip(keys=["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlip(keys=["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlip(keys=["image", "label"], spatial_axis=[2], prob=0.5),
            NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True),
            RandShiftIntensity(keys=["image"], offsets=0.10, prob=1.0),
        ]
    )
    result = pipeline(data, meta)
    return result["image"], result["label"]


def val_transformation(sample):
    meta = {"affine": sample["affine"]}
    data = {"image": sample["image"], "label": sample["label"]}

    pipeline = Compose(
        [
            ConvertToMultiChannelBasedOnBratsClasses(keys=["label"]),
            NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True),
        ]
    )
    result = pipeline(data, meta)
    return result["image"], result["label"]
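"""
The same `data`/`meta` convention can be exercised on synthetic tensors, which is handy when debugging a pipeline. This is a small optional sketch (random data, identity affine) rather than part of the training flow:

```python
dummy_data = {
    "image": tf.random.normal((8, 8, 8, 4)),
    "label": tf.zeros((8, 8, 8)),
}
dummy_meta = {"affine": tf.eye(4)}

pipeline = Compose(
    [
        ConvertToMultiChannelBasedOnBratsClasses(keys=["label"]),
        NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True),
    ]
)
out = pipeline(dummy_data, dummy_meta)
print(out["image"].shape, out["label"].shape)  # (8, 8, 8, 4) (8, 8, 8, 3)
```
"""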
"""
## The TFRecord Parser
"""


def parse_tfrecord_fn(example_proto):
    feature_description = {
        # Image raw data
        "flair_raw": tf.io.FixedLenFeature([], tf.string),
        "t1_raw": tf.io.FixedLenFeature([], tf.string),
        "t1ce_raw": tf.io.FixedLenFeature([], tf.string),
        "t2_raw": tf.io.FixedLenFeature([], tf.string),
        "label_raw": tf.io.FixedLenFeature([], tf.string),
        # Image shape
        "flair_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t1_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t1ce_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t2_shape": tf.io.FixedLenFeature([3], tf.int64),
        "label_shape": tf.io.FixedLenFeature([3], tf.int64),
        # Affine matrices (4x4 = 16 values)
        "flair_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t1_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t1ce_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t2_affine": tf.io.FixedLenFeature([16], tf.float32),
        "label_affine": tf.io.FixedLenFeature([16], tf.float32),
        # Voxel spacing (pixdim)
        "flair_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t1_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t1ce_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t2_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "label_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        # Filenames
        "flair_filename": tf.io.FixedLenFeature([], tf.string),
        "t1_filename": tf.io.FixedLenFeature([], tf.string),
        "t1ce_filename": tf.io.FixedLenFeature([], tf.string),
        "t2_filename": tf.io.FixedLenFeature([], tf.string),
        "label_filename": tf.io.FixedLenFeature([], tf.string),
    }

    example = tf.io.parse_single_example(example_proto, feature_description)

    # Decode image and label data
    flair = tf.io.decode_raw(example["flair_raw"], tf.float32)
    t1 = tf.io.decode_raw(example["t1_raw"], tf.float32)
    t1ce = tf.io.decode_raw(example["t1ce_raw"], tf.float32)
    t2 = tf.io.decode_raw(example["t2_raw"], tf.float32)
    label = tf.io.decode_raw(example["label_raw"], tf.float32)

    # Reshape to original dimensions
    flair = tf.reshape(flair, example["flair_shape"])
    t1 = tf.reshape(t1, example["t1_shape"])
    t1ce = tf.reshape(t1ce, example["t1ce_shape"])
    t2 = tf.reshape(t2, example["t2_shape"])
    label = tf.reshape(label, example["label_shape"])

    # Decode affine matrices
    flair_affine = tf.reshape(example["flair_affine"], (4, 4))
    t1_affine = tf.reshape(example["t1_affine"], (4, 4))
    t1ce_affine = tf.reshape(example["t1ce_affine"], (4, 4))
    t2_affine = tf.reshape(example["t2_affine"], (4, 4))
    label_affine = tf.reshape(example["label_affine"], (4, 4))

    # Add a channel axis and stack the four modalities channel-wise
    flair = flair[..., None]
    t1 = t1[..., None]
    t1ce = t1ce[..., None]
    t2 = t2[..., None]
    image = tf.concat([flair, t1, t1ce, t2], axis=-1)

    return {
        "image": image,
        "label": label,
        "affine": flair_affine,  # the affine is the same for all modalities
    }
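"""
The round trip at the heart of this parser is simple: serialize an array with `tobytes()`, decode it with `tf.io.decode_raw`, then restore the shape with `tf.reshape`. It can be illustrated on a small synthetic array (a standalone sketch, independent of the BraTS data):

```python
volume = np.random.rand(4, 4, 4).astype("float32")
raw_bytes = volume.tobytes()

decoded = tf.reshape(tf.io.decode_raw(raw_bytes, tf.float32), volume.shape)
print(np.allclose(volume, decoded.numpy()))  # True
```
"""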
"""
## Dataloader
"""


def train_dataloader(
    tfrecord_datalist,
    batch_size=1,
    shuffle_buffer=100,
):
    dataset = tf.data.TFRecordDataset(tfrecord_datalist)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        rearrange_shape,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        train_transformation,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(
        batch_size,
        drop_remainder=True,
    )
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


def val_dataloader(
    tfrecord_datalist,
    batch_size=1,
):
    dataset = tf.data.TFRecordDataset(tfrecord_datalist)
    dataset = dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        rearrange_shape,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        val_transformation,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


"""
The training batch size can be set to more than 1, depending on the environment and available resources. However, we intentionally keep the validation batch size at 1 to handle variable-sized samples more flexibly. While padded or ragged batches are alternative options, a batch size of 1 ensures simplicity and consistency during evaluation, especially for 3D medical data.
"""

tfrecord_pattern = "brats2020_subset/training_shard_*.tfrec"
datalist = sorted(
    tf.io.gfile.glob(tfrecord_pattern),
    key=lambda x: int(x.split("_")[-1].split(".")[0]),
)

train_datalist = datalist[:-1]
val_datalist = datalist[-1:]
print(len(train_datalist), len(val_datalist))

train_ds = train_dataloader(train_datalist, batch_size=1)
val_ds = val_dataloader(val_datalist, batch_size=1)

"""
**Sanity check**: Fetch a single validation sample to inspect its shape and values.
"""

val_x, val_y = next(iter(val_ds))
test_image = val_x.numpy().squeeze()
test_mask = val_y.numpy().squeeze()
print(test_image.shape, test_mask.shape, np.unique(test_mask))
print(test_image.min(), test_image.max())

"""
**Sanity check**: Visualize the middle slice of the image and its corresponding label.
"""

slice_no = test_image.shape[0] // 2

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(test_image[slice_no], cmap="gray")
ax1.set_title(f"Image shape: {test_image.shape}")
ax2.imshow(test_mask[slice_no])
ax2.set_title(f"Label shape: {test_mask.shape}")
plt.show()

"""
**Sanity check**: Visualize the image and label channels at the middle slice index.
"""

print(f"image shape: {test_image.shape}")
plt.figure("image", (24, 6))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.title(f"image channel {i}")
    plt.imshow(test_image[slice_no, :, :, i], cmap="gray")
plt.show()


print(f"label shape: {test_mask.shape}")
plt.figure("label", (18, 6))
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.title(f"label channel {i}")
    plt.imshow(test_mask[slice_no, :, :, i])
plt.show()
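"""
Beyond visual inspection, it can also help to glance at the label statistics of the sample. This is a small optional check on the multi-channel mask produced above:

```python
for i, name in enumerate(["TC", "WT", "ET"]):
    voxels = int(test_mask[..., i].sum())
    ratio = 100 * voxels / test_mask[..., i].size
    print(f"{name}: {voxels} positive voxels ({ratio:.3f}% of the volume)")
```
"""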
"""
## Model

We will use the 3D Swin UNEt TRansformers architecture, i.e., [`SwinUNETR`](https://arxiv.org/abs/2201.01266). It was used by NVIDIA in the BraTS 2021 segmentation challenge, where it was among the top-performing methods. It uses a Swin Transformer encoder to extract features at five different resolutions, and a CNN-based decoder is connected to each resolution through skip connections.

The BraTS dataset provides four input modalities (`flair`, `t1`, `t1ce`, and `t2`) and three multi-label outputs (`tumor-core`, `whole-tumor`, and `enhancing-tumor`). Accordingly, we instantiate the model with `4` input channels and `3` output channels.
"""

"""
You can check the available models in `medicai` with:

```python
# medicai.models.list_models()
```
"""

model = SwinUNETR(
    encoder_name="swin_tiny_v2",
    input_shape=input_shape,
    num_classes=num_classes,
    classifier_activation=None,
)

model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-5,
    ),
    loss=BinaryDiceCELoss(
        from_logits=True,
        num_classes=num_classes,
    ),
    metrics=[
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            num_classes=num_classes,
            name="dice",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[0],
            num_classes=num_classes,
            name="dice_tc",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[1],
            num_classes=num_classes,
            name="dice_wt",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[2],
            num_classes=num_classes,
            name="dice_et",
        ),
    ],
)

# NOTE: `instance_describe()` is a `medicai`-specific summary helper;
# skip it if it is not available in the installed version.
try:
    print(model.instance_describe())
except AttributeError:
    pass
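"""
All of the losses and metrics above are built around the Dice score, which measures the overlap between a predicted mask `A` and a ground-truth mask `B` as `2 * |A intersect B| / (|A| + |B|)`. Here is a minimal NumPy sketch of the idea, ignoring the extra handling the library adds (logits, per-class grouping, empty masks):

```python
def dice_coefficient(y_true, y_pred, eps=1e-7):
    # y_true and y_pred are binary arrays of the same shape.
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + eps) / (np.sum(y_true) + np.sum(y_pred) + eps)


a = np.array([0, 1, 1, 0])
b = np.array([0, 1, 0, 0])
print(dice_coefficient(a, b))  # 2 * 1 / (2 + 1) ~= 0.67
```
"""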
"""
## Callback

We will use the sliding window inference callback from `medicai` to evaluate the model on the validation set at a fixed epoch interval during training. The `interval` should be chosen relative to the total number of epochs; for example, if we train for 15 epochs and want to evaluate every 5 epochs, we set `interval=5`.
"""

swi_callback_metric = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    num_classes=num_classes,
    name="val_dice",
)

swi_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds,
    metrics=swi_callback_metric,
    num_classes=num_classes,
    interval=2,
    overlap=0.5,
    roi_size=input_shape[:3],
    sw_batch_size=4,
    mode="gaussian",
    save_path="brats.model.weights.h5",
)

"""
## Training

We train for only a few epochs here to keep the runtime short; train for more epochs to get better results.
"""

history = model.fit(train_ds, epochs=epochs, callbacks=[swi_callback])

# Comment this line to keep the logs visible
clear_output()

"""
Let's take a quick look at how our model performed during training. We will first print the available metrics recorded in the training history, save them to a CSV file for future reference, and then visualize them to better understand the model's learning progress over epochs.
"""


def plot_training_history(history_df):
    metrics = history_df.columns
    n_metrics = len(metrics)

    n_rows = 2
    n_cols = (n_metrics + 1) // 2  # ceiling division for columns

    plt.figure(figsize=(5 * n_cols, 5 * n_rows))

    for idx, metric in enumerate(metrics):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.plot(history_df[metric], label=metric, marker="o")
        plt.title(metric)
        plt.xlabel("Epoch")
        plt.ylabel("Value")
        plt.grid(True)
        plt.legend()

    plt.tight_layout()
    plt.show()


print(model.history.history.keys())
his_csv = pd.DataFrame(model.history.history)
his_csv.to_csv("brats.history.csv")
plot_training_history(his_csv)

"""
## Evaluation

In this [Kaggle notebook](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-multi-gpu/) (version 5), we trained the model on the entire dataset for approximately `30` epochs. The resulting weights will be used for further evaluation. Note that the validation set used here and in the Kaggle notebook is the same: `training_shard_36.tfrec`, which contains `8` samples.
"""

model_weight = kagglehub.model_download(
    "ipythonx/bratsmodel/keras/default", path="brats.model.weights.h5"
)
print("\nPath to model files:", model_weight)

model.load_weights(model_weight)

"""
In this section, we perform sliding window inference on the validation dataset and compute Dice scores for overall segmentation quality as well as for specific tumor sub-regions:

- Tumor Core (TC)
- Whole Tumor (WT)
- Enhancing Tumor (ET)
"""

swi = SlidingWindowInference(
    model,
    num_classes=num_classes,
    roi_size=input_shape[:3],
    sw_batch_size=4,
    overlap=0.5,
    mode="gaussian",
)
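"""
To get a rough sense of the inference cost, we can estimate how many `96^3` windows are needed to tile a typical `155 x 240 x 240` BraTS volume at `0.5` overlap. This is only a back-of-the-envelope sketch; the actual tiling and blending strategy is handled internally by `SlidingWindowInference`:

```python
import math


def estimate_num_windows(image_size, roi_size, overlap=0.5):
    # Stride is the ROI size scaled by (1 - overlap) along each axis.
    total = 1
    for dim, roi in zip(image_size, roi_size):
        stride = max(int(roi * (1 - overlap)), 1)
        total *= max(math.ceil((dim - roi) / stride) + 1, 1)
    return total


print(estimate_num_windows((155, 240, 240), (96, 96, 96)))  # 3 * 4 * 4 = 48 windows
```
"""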
dice = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    num_classes=num_classes,
    name="dice",
)
dice_tc = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[0],
    num_classes=num_classes,
    name="dice_tc",
)
dice_wt = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[1],
    num_classes=num_classes,
    name="dice_wt",
)
dice_et = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[2],
    num_classes=num_classes,
    name="dice_et",
)

"""
Because the validation volumes are large and vary in size, we iterate over the validation dataloader one batch at a time. The sliding window inference splits each volume into patches and aggregates the predictions.
"""

dice.reset_state()
dice_tc.reset_state()
dice_wt.reset_state()
dice_et.reset_state()

for sample in val_ds:
    x, y = sample
    output = swi(x)
    dice.update_state(y, output)
    dice_tc.update_state(y, output)
    dice_wt.update_state(y, output)
    dice_et.update_state(y, output)

dice_score = float(ops.convert_to_numpy(dice.result()))
dice_score_tc = float(ops.convert_to_numpy(dice_tc.result()))
dice_score_wt = float(ops.convert_to_numpy(dice_wt.result()))
dice_score_et = float(ops.convert_to_numpy(dice_et.result()))

# Comment this line to keep the logs visible
clear_output()

print(f"Dice Score: {dice_score:.4f}")
print(f"Dice Score on tumor core (TC): {dice_score_tc:.4f}")
print(f"Dice Score on whole tumor (WT): {dice_score_wt:.4f}")
print(f"Dice Score on enhancing tumor (ET): {dice_score_et:.4f}")

"""
## Analyse and Visualize

Let's analyse the model predictions and visualize them. First, we implement the test transformation pipeline, which is the same as the validation transformation.
"""


def test_transformation(sample):
    return val_transformation(sample)


"""
Let's load the TFRecord file and check its properties.
"""

index = 0
dataset = tf.data.TFRecordDataset(val_datalist[index])
dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(rearrange_shape, num_parallel_calls=tf.data.AUTOTUNE)

sample = next(iter(dataset))
orig_image = sample["image"]
orig_label = sample["label"]
print(orig_image.shape, orig_label.shape, np.unique(orig_label))

"""
Run the transformation to prepare the inputs.
"""

pre_image, pre_label = test_transformation(sample)
print(pre_image.shape, pre_label.shape)

"""
Pass the preprocessed sample to the inference object, ensuring that a batch axis is added to the input beforehand.
"""

y_pred = swi(pre_image[None, ...])

# Comment this line to keep the logs visible
clear_output()

print(y_pred.shape)

"""
After running inference, we remove the batch dimension and apply a `sigmoid` activation to obtain class probabilities. We then threshold the probabilities at `0.5` to generate the final binary segmentation map.
"""

y_pred_logits = y_pred.squeeze(axis=0)
y_pred_prob = ops.convert_to_numpy(ops.sigmoid(y_pred_logits))
segment = (y_pred_prob > 0.5).astype(int)
print(segment.shape, np.unique(segment))


"""
We compare the ground truth (`pre_label`) and the predicted segmentation (`segment`) for each tumor sub-region. Each sub-plot shows a specific channel corresponding to a tumor type: TC, WT, and ET. Here we visualize the `80th` axial slice across the three channels.
"""

label_map = {0: "TC", 1: "WT", 2: "ET"}

plt.figure(figsize=(16, 4))
for i in range(pre_label.shape[-1]):
    plt.subplot(1, 3, i + 1)
    plt.title(f"label channel {label_map[i]}")
    plt.imshow(pre_label[80, :, :, i])
plt.show()

plt.figure(figsize=(16, 4))
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.title(f"pred channel {label_map[i]}")
    plt.imshow(segment[80, :, :, i])
plt.show()
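"""
As an optional per-case check, we can score this sample directly with the toy `dice_coefficient` sketch from earlier (note that this simplified score will not exactly match the library's `BinaryDiceMetric`):

```python
for i, name in label_map.items():
    score = dice_coefficient(pre_label.numpy()[..., i], segment[..., i])
    print(f"{name}: {score:.4f}")
```
"""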
"""
The predicted output is a multi-channel binary map, where each channel corresponds to a specific tumor region. To visualize it against the original ground truth, we convert it into a single-channel label map. Here we assign:

- Label 1 for Tumor Core (TC)
- Label 2 for Whole Tumor (WT)
- Label 4 for Enhancing Tumor (ET)

Because the regions are nested (ET inside TC, TC inside WT), we paint the whole tumor first and then overwrite with the tumor core and the enhancing tumor, so that each voxel keeps its most specific label. The label values are chosen to match typical conventions used in medical segmentation benchmarks like BraTS.
"""

prediction = np.zeros(
    (segment.shape[0], segment.shape[1], segment.shape[2]), dtype="float32"
)
prediction[segment[..., 1] == 1] = 2
prediction[segment[..., 0] == 1] = 1
prediction[segment[..., 2] == 1] = 4

print("label ", orig_label.shape, np.unique(orig_label))
print("predicted ", prediction.shape, np.unique(prediction))


"""
Let's begin by examining the original input slices from the MRI scan. The input contains four channels corresponding to different MRI modalities:

- FLAIR
- T1
- T1CE (T1 with contrast enhancement)
- T2

We display the same slice number across all modalities for comparison.
"""

slice_map = {0: "flair", 1: "t1", 2: "t1ce", 3: "t2"}
slice_num = 75

plt.figure(figsize=(16, 4))
for i in range(orig_image.shape[-1]):
    plt.subplot(1, 4, i + 1)
    plt.title(f"Original channel: {slice_map[i]}")
    plt.imshow(orig_image[slice_num, :, :, i], cmap="gray")

plt.tight_layout()
plt.show()

"""
Next, we compare this input with the ground truth label and the predicted segmentation on the same slice. This provides visual insight into how well the model has localized and segmented the tumor regions.
"""

num_channels = orig_image.shape[-1]
plt.figure("image", (15, 15))

# plotting image, label and prediction
plt.subplot(3, num_channels, num_channels + 1)
plt.title("image")
plt.imshow(orig_image[slice_num, :, :, 0], cmap="gray")

plt.subplot(3, num_channels, num_channels + 2)
plt.title("label")
plt.imshow(orig_label[slice_num, :, :])

plt.subplot(3, num_channels, num_channels + 3)
plt.title("prediction")
plt.imshow(prediction[slice_num, :, :])

plt.tight_layout()
plt.show()

"""
Finally, create a clean GIF visualizer showing the input image, ground-truth label, and model prediction.
"""
# The input volume contains large black margins, so we crop
# the foreground region of interest (ROI).
crop_foreground = CropForeground(
    keys=("image", "label", "prediction"), source_key="image"
)

data = {
    "image": orig_image,
    "label": orig_label[..., None],
    "prediction": prediction[..., None],
}
results = crop_foreground(data)
crop_orig_image = results["image"]
crop_orig_label = results["label"]
crop_prediction = results["prediction"]

"""
Prepare a visualization-friendly prediction map by remapping the label values to a compact index range.
"""

viz_pred = np.zeros_like(crop_prediction, dtype="uint8")
viz_pred[crop_prediction == 1] = 1
viz_pred[crop_prediction == 2] = 2
viz_pred[crop_prediction == 4] = 3

# Colormap for background, tumor core, edema, and enhancing regions
cmap = ListedColormap(
    [
        "#000000",  # background
        "#E57373",  # muted red
        "#64B5F6",  # muted blue
        "#81C784",  # muted green
    ]
)

# Create side-by-side views for input, label, and prediction
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
ax_img, ax_lbl, ax_pred = axes

img_im = ax_img.imshow(crop_orig_image[0, :, :, 0], cmap="gray")
lbl_im = ax_lbl.imshow(
    crop_orig_label[0], vmin=0, vmax=3, cmap=cmap, interpolation="nearest"
)
pred_im = ax_pred.imshow(
    viz_pred[0], vmin=0, vmax=3, cmap=cmap, interpolation="nearest"
)

# Tight layout for a compact GIF
plt.subplots_adjust(left=0.01, right=0.99, bottom=0.02, top=0.8, wspace=0.01)

for ax, t in zip(axes, ["FLAIR", "Label", "Prediction"]):
    ax.set_title(t, fontsize=19, pad=10)
    ax.axis("off")
    ax.set_adjustable("box")


def update(i):
    img_im.set_data(crop_orig_image[i, :, :, 0])
    lbl_im.set_data(crop_orig_label[i])
    pred_im.set_data(viz_pred[i])
    fig.suptitle(f"Slice {i}", fontsize=14)
    return img_im, lbl_im, pred_im


ani = animation.FuncAnimation(
    fig, update, frames=crop_orig_image.shape[0], interval=120
)
ani.save(
    "segmentation_slices.gif",
    writer="pillow",
    dpi=100,
)
plt.close(fig)
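"""
If you are running this tutorial in a notebook, the saved GIF can be embedded directly in the output cell (behavior may vary across notebook frontends):

```python
from IPython.display import Image

Image(filename="segmentation_slices.gif")
```
"""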
"""
When you open the saved GIF, you should see the FLAIR input, the ground-truth label, and the model prediction animated side by side through the slices of the volume.
"""

"""
## Additional Resources

- [BraTS Segmentation on Multi-GPU](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-multi-gpu)
- [BraTS Segmentation on TPU-VM](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-tpu-vm)
- [BraTS .nii to TFRecord](https://www.kaggle.com/code/ipythonx/brats-nii-to-tfrecord)
- [Covid-19 Segmentation](https://www.kaggle.com/code/ipythonx/medicai-covid-19-3d-image-segmentation)
- [3D Multi-organ Segmentation](https://www.kaggle.com/code/ipythonx/medicai-3d-btcv-segmentation-in-keras)
- [Spleen 3D segmentation](https://www.kaggle.com/code/ipythonx/medicai-spleen-3d-segmentation-in-keras)
- [3D Medical Image Transformation](https://www.kaggle.com/code/ipythonx/medicai-3d-medical-image-transformation)
"""