"""
Title: 3D Multimodal Brain Tumor Segmentation
Author: [Mohammed Innat](https://www.linkedin.com/in/innat2k14/)
Date created: 2026/02/02
Last modified: 2026/02/02
Description: Implementing a 3D semantic segmentation pipeline for medical imaging.
Accelerator: GPU
"""

"""
Brain tumor segmentation is a core task in medical image analysis, where the goal is to automatically identify and label different tumor sub-regions from 3D MRI scans. Accurate segmentation helps clinicians with diagnosis, treatment planning, and disease monitoring. In this tutorial, we focus on multimodal MRI-based brain tumor segmentation using the widely adopted **BraTS** (**Brain Tumor Segmentation**) dataset.

## The BraTS Dataset

The **BraTS** dataset provides multimodal 3D brain MRI scans, released as NIfTI files (`.nii.gz`). For each patient, four MRI modalities are available:

- **T1** – native T1-weighted MRI
- **T1Gd** – post-contrast T1-weighted MRI
- **T2** – T2-weighted MRI
- **T2-FLAIR** – Fluid Attenuated Inversion Recovery MRI

The scans are acquired with different scanners and clinical protocols across 19 institutions, making the dataset diverse and realistic. More details about the dataset can be found in the official [BraTS documentation](https://www.med.upenn.edu/cbica/brats2020/data.html).
"""

"""
## Segmentation Labels

Each scan is manually annotated by **one to four expert raters**, following a standardized annotation protocol, and reviewed by experienced neuroradiologists. The segmentation masks contain the following tumor sub-regions:

- **NCR / NET (label 1)** – Necrotic and non-enhancing tumor core
- **ED (label 2)** – Peritumoral edema
- **ET (label 4)** – GD-enhancing tumor
- **Background (label 0)** – Non-tumor tissue

The data are released after preprocessing:

- All modalities are **co-registered**
- Resampled to `1 mm³` isotropic resolution
- **Skull-stripped** for consistency
"""

"""
## Dataset Format and TFRecord Conversion

The original BraTS scans are provided in `.nii` format and can be accessed from Kaggle [here](https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation/). To enable **efficient training pipelines**, we convert the NIfTI files into **TFRecord** format:

- The conversion process is documented [here](https://www.kaggle.com/code/ipythonx/brats-nii-to-tfrecord)
- The preprocessed TFRecord dataset is available [here](https://www.kaggle.com/datasets/ipythonx/brats2020)
- Each TFRecord file contains **up to 20 cases**

Since BraTS does not provide publicly available ground-truth labels for validation or test sets, we will **hold out a subset of TFRecord files** from training for validation purposes.

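The exact conversion script is linked above. As a rough illustration, the minimal sketch below (assuming `nibabel` for NIfTI I/O and hypothetical file names) shows how one modality of a single case could be serialized into the feature schema that the parser later in this tutorial expects; the real script repeats this for `t1`, `t1ce`, `t2`, and the label volume.

```python
import nibabel as nib
import numpy as np
import tensorflow as tf


def nii_to_example(flair_path):
    # Load one modality and its affine matrix from the NIfTI file.
    vol = nib.load(flair_path)
    arr = vol.get_fdata(dtype=np.float32)
    feature = {
        "flair_raw": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[arr.tobytes()])
        ),
        "flair_shape": tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(arr.shape))
        ),
        "flair_affine": tf.train.Feature(
            float_list=tf.train.FloatList(value=vol.affine.flatten().tolist())
        ),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


with tf.io.TFRecordWriter("training_shard_0.tfrec") as writer:
    example = nii_to_example("BraTS20_Training_001_flair.nii")
    writer.write(example.SerializeToString())
```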

## What This Tutorial Covers

In this tutorial, we provide a step-by-step, end-to-end workflow for brain tumor segmentation using [medicai](https://github.com/innat/medic-ai), a Keras-based medical imaging library with multi-backend support. We will walk through:

1. **Loading the Dataset**
    - Read TFRecord files that contain `image`, `label`, and `affine` matrix information.
    - Build efficient data pipelines using the `tf.data` API for training and evaluation.
2. **Medical Image Preprocessing**
    - Apply image transformations provided by `medicai` to prepare the data for model input.
3. **Model Building**
    - Construct a 3D segmentation model with [`SwinUNETR`](https://arxiv.org/abs/2201.01266). You can also experiment with other available 3D architectures, including [`UNETR`](https://arxiv.org/abs/2103.10504), [`SegFormer`](https://arxiv.org/abs/2404.10156), and [`UNETR++`](https://ieeexplore.ieee.org/document/10526382).
4. **Loss and Metrics Definition**
    - Use Dice-based loss functions and segmentation metrics tailored for medical imaging.
5. **Model Evaluation**
    - Perform inference on large 3D volumes using **sliding window inference**.
    - Compute per-class evaluation metrics.
6. **Visualization of Results**
    - Visualize predicted segmentation masks for qualitative analysis.

By the end of this tutorial, you will have a complete brain tumor segmentation pipeline, from data loading and preprocessing to model training, evaluation, and visualization, using modern 3D deep learning techniques and the `medicai` framework.
"""

"""
## Installation

We will install the following packages: [`kagglehub`](https://github.com/Kaggle/kagglehub) for downloading the dataset from
Kaggle, and [`medicai`](https://github.com/innat/medic-ai) for accessing specialized methods for medical imaging, including 3D transformations, model architectures, loss functions, metrics, and other essential components.

```shell
!pip install kagglehub -qU
!pip install git+https://github.com/innat/medic-ai.git -qU
```
"""
import os
import warnings

warnings.filterwarnings("ignore")

import shutil
import kagglehub
from IPython.display import clear_output

if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()

"""
Download the dataset from Kaggle.
"""
dataset_id = "ipythonx/brats2020"
destination_path = "brats2020_subset"
os.makedirs(destination_path, exist_ok=True)

# Download three shards: shards 0 and 1 for the training set, shard 36 for the validation set.
for i in [0, 1, 36]:
    filename = f"training_shard_{i}.tfrec"
    print(f"Downloading {filename}...")
    path = kagglehub.dataset_download(dataset_id, path=filename)
    shutil.move(path, destination_path)

# Comment this line to keep the logs visible
clear_output()

"""
## Imports
"""
os.environ["KERAS_BACKEND"] = "jax"  # tensorflow, torch, jax

import numpy as np
import pandas as pd

import keras
from keras import ops
import tensorflow as tf

from matplotlib import pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import ListedColormap

from medicai.callbacks import SlidingWindowInferenceCallback
from medicai.losses import BinaryDiceCELoss
from medicai.metrics import BinaryDiceMetric
from medicai.models import SwinUNETR
from medicai.transforms import (
    Compose,
    CropForeground,
    NormalizeIntensity,
    RandFlip,
    RandShiftIntensity,
    RandSpatialCrop,
    TensorBundle,
)
from medicai.utils.inference import SlidingWindowInference

# enable mixed precision
keras.mixed_precision.set_global_policy("mixed_float16")

# reproducibility
keras.utils.set_random_seed(101)

print(
    f"keras backend: {keras.config.backend()}\n"
    f"keras version: {keras.version()}\n"
    f"tensorflow version: {tf.__version__}\n"
)

"""
## Create Multi-label Brain Tumor Labels

The BraTS segmentation task involves multiple tumor sub-regions and is formulated as a multi-label segmentation problem. The label combinations are used to define the following clinical regions of interest:

```shell
- Tumor Core (TC): label = 1 or 4
- Whole Tumor (WT): label = 1 or 2 or 4
- Enhancing Tumor (ET): label = 4
```

These region-wise groupings allow for evaluation across different tumor structures relevant for clinical assessment and treatment planning. A sample view is shown below; the figure is taken from the [BraTS benchmark](https://arxiv.org/abs/2107.02314) paper.

![](https://i.imgur.com/Agnwpxm.png)

## Managing Multi-Label Outputs with `TensorBundle`

To organize and manage these multi-label segmentation targets, we will implement a custom transformation using [**TensorBundle**](https://github.com/innat/medic-ai/blob/2d2139020531acd1c2d41b07a10daf04ceb150f4/medicai/transforms/tensor_bundle.py#L9) from `medicai`. `TensorBundle` is a lightweight container class designed to hold:

- A dictionary of tensors (e.g., images, labels)
- Optional metadata associated with those tensors (e.g., affine matrices, spacing, original shapes)

This design allows data and metadata to be passed together through the transformation pipeline in a structured and consistent way. Each `medicai` transformation expects inputs to be organized as `key:value` pairs, for example:

```python
meta = {"affine": affine}
data = {"image": image, "label": label}
```

Using `TensorBundle` makes it easier to apply complex medical imaging transformations while preserving spatial and anatomical information throughout preprocessing and model inference.
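
For instance, a single transform can be called with a plain dictionary; the dictionary is wrapped into a `TensorBundle` internally and the result can be indexed by the same keys. A minimal sketch (illustrative only and not executed here; `image` and `label` stand for arbitrary volume tensors):

```python
data = {"image": image, "label": label}

normalize = NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True)
result = normalize(data)  # the dict is wrapped into a TensorBundle internally

normalized_image = result["image"]  # transformed tensor
unchanged_label = result["label"]  # carried along untouched
```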
"""


class ConvertToMultiChannelBasedOnBratsClasses:
    """
    Convert labels to a multi-channel format based on BraTS classes using TensorFlow.

    Label definitions:
    - 1: necrotic and non-enhancing tumor core
    - 2: peritumoral edema
    - 4: GD-enhancing tumor

    Output channels:
    - Channel 0 (TC): Tumor core (labels 1 or 4)
    - Channel 1 (WT): Whole tumor (labels 1, 2, or 4)
    - Channel 2 (ET): Enhancing tumor (label 4)
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, inputs):
        if isinstance(inputs, dict):
            inputs = TensorBundle(inputs)

        for key in self.keys:
            data = inputs[key]

            # TC: label == 1 or 4
            tc = tf.logical_or(tf.equal(data, 1), tf.equal(data, 4))

            # WT: label == 1 or 2 or 4
            wt = tf.logical_or(tc, tf.equal(data, 2))

            # ET: label == 4
            et = tf.equal(data, 4)

            stacked = tf.stack(
                [
                    tf.cast(tc, tf.float32),
                    tf.cast(wt, tf.float32),
                    tf.cast(et, tf.float32),
                ],
                axis=-1,
            )

            inputs[key] = stacked
        return inputs


"""
## Transformation

Each `medicai` transformation expects the input to have the shape `(depth, height, width, channel)`. The original `.nii` files (and the converted `.tfrecord` data) store volumes as `(height, width, depth)`. To make them compatible with `medicai`, we need to rearrange the axes.
"""


def rearrange_shape(sample):
    # Unpack the sample
    image = sample["image"]
    label = sample["label"]
    affine = sample["affine"]

    # Re-arrange the spatial axes so that depth comes first
    image = tf.transpose(image, perm=[2, 1, 0, 3])  # whdc -> dhwc
    label = tf.transpose(label, perm=[2, 1, 0])  # whd -> dhw
    cols = tf.gather(affine, [2, 1, 0], axis=1)  # (whd) -> (dhw)
    affine = tf.concat([cols, affine[:, 3:]], axis=1)

    # Update the sample with the re-arranged tensors
    sample["image"] = image
    sample["label"] = label
    sample["affine"] = affine
    return sample


"""
Each transformation class in `medicai` accepts either a dictionary or a `TensorBundle` object, as discussed earlier. When a dictionary of input data (along with metadata) is passed, it is automatically wrapped into a `TensorBundle` instance. The examples below demonstrate how transformations are used in this way.
"""

num_classes = 3
epochs = 4
input_shape = (96, 96, 96, 4)


def train_transformation(sample):
    meta = {"affine": sample["affine"]}
    data = {"image": sample["image"], "label": sample["label"]}

    pipeline = Compose(
        [
            ConvertToMultiChannelBasedOnBratsClasses(keys=["label"]),
            CropForeground(
                keys=("image", "label"),
                source_key="image",
                k_divisible=input_shape[:3],
            ),
            RandSpatialCrop(
                keys=["image", "label"], roi_size=input_shape[:3], random_size=False
            ),
            RandFlip(keys=["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlip(keys=["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlip(keys=["image", "label"], spatial_axis=[2], prob=0.5),
            NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True),
            RandShiftIntensity(keys=["image"], offsets=0.10, prob=1.0),
        ]
    )
    result = pipeline(data, meta)
    return result["image"], result["label"]


def val_transformation(sample):
    meta = {"affine": sample["affine"]}
    data = {"image": sample["image"], "label": sample["label"]}

    pipeline = Compose(
        [
            ConvertToMultiChannelBasedOnBratsClasses(keys=["label"]),
            NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=True),
        ]
    )
    result = pipeline(data, meta)
    return result["image"], result["label"]


"""
## The TFRecord Parser
"""


def parse_tfrecord_fn(example_proto):
    feature_description = {
        # Image raw data
        "flair_raw": tf.io.FixedLenFeature([], tf.string),
        "t1_raw": tf.io.FixedLenFeature([], tf.string),
        "t1ce_raw": tf.io.FixedLenFeature([], tf.string),
        "t2_raw": tf.io.FixedLenFeature([], tf.string),
        "label_raw": tf.io.FixedLenFeature([], tf.string),
        # Image shape
        "flair_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t1_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t1ce_shape": tf.io.FixedLenFeature([3], tf.int64),
        "t2_shape": tf.io.FixedLenFeature([3], tf.int64),
        "label_shape": tf.io.FixedLenFeature([3], tf.int64),
        # Affine matrices (4x4 = 16 values)
        "flair_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t1_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t1ce_affine": tf.io.FixedLenFeature([16], tf.float32),
        "t2_affine": tf.io.FixedLenFeature([16], tf.float32),
        "label_affine": tf.io.FixedLenFeature([16], tf.float32),
        # Voxel Spacing (pixdim)
        "flair_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t1_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t1ce_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "t2_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        "label_pixdim": tf.io.FixedLenFeature([8], tf.float32),
        # Filenames
        "flair_filename": tf.io.FixedLenFeature([], tf.string),
        "t1_filename": tf.io.FixedLenFeature([], tf.string),
        "t1ce_filename": tf.io.FixedLenFeature([], tf.string),
        "t2_filename": tf.io.FixedLenFeature([], tf.string),
        "label_filename": tf.io.FixedLenFeature([], tf.string),
    }

    example = tf.io.parse_single_example(example_proto, feature_description)

    # Decode image and label data
    flair = tf.io.decode_raw(example["flair_raw"], tf.float32)
    t1 = tf.io.decode_raw(example["t1_raw"], tf.float32)
    t1ce = tf.io.decode_raw(example["t1ce_raw"], tf.float32)
    t2 = tf.io.decode_raw(example["t2_raw"], tf.float32)
    label = tf.io.decode_raw(example["label_raw"], tf.float32)

    # Reshape to original dimensions
    flair = tf.reshape(flair, example["flair_shape"])
    t1 = tf.reshape(t1, example["t1_shape"])
    t1ce = tf.reshape(t1ce, example["t1ce_shape"])
    t2 = tf.reshape(t2, example["t2_shape"])
    label = tf.reshape(label, example["label_shape"])

    # Decode affine matrices
    flair_affine = tf.reshape(example["flair_affine"], (4, 4))
    t1_affine = tf.reshape(example["t1_affine"], (4, 4))
    t1ce_affine = tf.reshape(example["t1ce_affine"], (4, 4))
    t2_affine = tf.reshape(example["t2_affine"], (4, 4))
    label_affine = tf.reshape(example["label_affine"], (4, 4))

    # Add a channel axis and stack the four modalities
    flair = flair[..., None]
    t1 = t1[..., None]
    t1ce = t1ce[..., None]
    t2 = t2[..., None]
    image = tf.concat([flair, t1, t1ce, t2], axis=-1)

    return {
        "image": image,
        "label": label,
        "affine": flair_affine,  # Since affine is the same for all
    }


"""
## Dataloader
"""


def train_dataloader(
    tfrecord_datalist,
    batch_size=1,
    shuffle_buffer=100,
):
    dataset = tf.data.TFRecordDataset(tfrecord_datalist)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        rearrange_shape,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        train_transformation,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(
        batch_size,
        drop_remainder=True,
    )
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


def val_dataloader(
    tfrecord_datalist,
    batch_size=1,
):
    dataset = tf.data.TFRecordDataset(tfrecord_datalist)
    dataset = dataset.map(
        parse_tfrecord_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        rearrange_shape,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(
        val_transformation,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


"""
The training batch size can be set to more than 1, depending on the environment and available resources. However, we intentionally keep the validation batch size at 1 to handle variable-sized samples more flexibly. While padded or ragged batches are alternative options, a batch size of 1 ensures simplicity and consistency during evaluation, especially for 3D medical data.
"""

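"""
If you do want to batch variable-sized validation volumes, padded batching in `tf.data` is one alternative (not used in this example). A minimal sketch, assuming the dataset yields `(image, label)` pairs with 4 image channels and 3 label channels:

```python
dataset = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([None, None, None, 4], [None, None, None, 3]),
)
```
"""
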
tfrecord_pattern = "brats2020_subset/training_shard_*.tfrec"
datalist = sorted(
    tf.io.gfile.glob(tfrecord_pattern),
    key=lambda x: int(x.split("_")[-1].split(".")[0]),
)

# Use the last shard (shard 36) for validation and the rest for training.
train_datalist = datalist[:-1]
val_datalist = datalist[-1:]
print(len(train_datalist), len(val_datalist))

train_ds = train_dataloader(train_datalist, batch_size=1)
val_ds = val_dataloader(val_datalist, batch_size=1)

"""
**sanity check**: Fetch a single validation sample to inspect its shape and values.
"""

val_x, val_y = next(iter(val_ds))
test_image = val_x.numpy().squeeze()
test_mask = val_y.numpy().squeeze()
print(test_image.shape, test_mask.shape, np.unique(test_mask))
print(test_image.min(), test_image.max())

"""
**sanity check**: Visualize the middle slice of the image and its corresponding label.
"""

slice_no = test_image.shape[0] // 2

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.imshow(test_image[slice_no], cmap="gray")
ax1.set_title(f"Image shape: {test_image.shape}")
ax2.imshow(test_mask[slice_no])
ax2.set_title(f"Label shape: {test_mask.shape}")
plt.show()

"""
**sanity check**: Visualize sample image and label channels at the middle slice index.
"""

print(f"image shape: {test_image.shape}")
plt.figure("image", (24, 6))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.title(f"image channel {i}")
    plt.imshow(test_image[slice_no, :, :, i], cmap="gray")
plt.show()


print(f"label shape: {test_mask.shape}")
plt.figure("label", (18, 6))
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.title(f"label channel {i}")
    plt.imshow(test_mask[slice_no, :, :, i])
plt.show()

"""
## Model

We will use the Swin UNEt TRansformers architecture, [`SwinUNETR`](https://arxiv.org/abs/2201.01266), for 3D segmentation. NVIDIA used this model in the BraTS 2021 segmentation challenge, where it was among the top-performing methods. It uses a Swin Transformer encoder to extract features at five different resolutions, and a CNN-based decoder is connected to each resolution through skip connections.

The BraTS dataset provides four input modalities (`flair`, `t1`, `t1ce`, and `t2`) and three multi-label outputs (`tumor-core`, `whole-tumor`, and `enhancing-tumor`). Accordingly, we initialize the model with `4` input channels and `3` output channels.

![](https://i.imgur.com/OInMRGp.png)
"""
"""
```shell
# Check available models
# medicai.models.list_models()
```
"""

model = SwinUNETR(
    encoder_name="swin_tiny_v2",
    input_shape=input_shape,
    num_classes=num_classes,
    classifier_activation=None,
)

model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-5,
    ),
    loss=BinaryDiceCELoss(
        from_logits=True,
        num_classes=num_classes,
    ),
    metrics=[
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            num_classes=num_classes,
            name="dice",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[0],
            num_classes=num_classes,
            name="dice_tc",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[1],
            num_classes=num_classes,
            name="dice_wt",
        ),
        BinaryDiceMetric(
            from_logits=True,
            ignore_empty=True,
            target_class_ids=[2],
            num_classes=num_classes,
            name="dice_et",
        ),
    ],
)

# NOTE: `instance_describe()` may not be available in all versions of medicai.
try:
    print(model.instance_describe())
except AttributeError:
    pass

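"""
Both the loss and the metrics above are built around the Dice coefficient, which measures the overlap between a predicted mask `P` and a ground-truth mask `G`:

`Dice(P, G) = 2 * |P ∩ G| / (|P| + |G|)`

`BinaryDiceCELoss` (as the name suggests) combines a Dice term with binary cross-entropy, while `BinaryDiceMetric` reports the Dice score, here both overall and per region (`dice_tc`, `dice_wt`, `dice_et`) by selecting the corresponding `target_class_ids`.
"""
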
"""
## Callback

We will use the sliding window inference callback from `medicai` to run validation at a fixed epoch interval during training. The `interval` argument should be chosen based on the total number of epochs. For example, if we train for 15 epochs and want to evaluate the model on the validation set every 5 epochs, we should set `interval` to 5.
"""

swi_callback_metric = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    num_classes=num_classes,
    name="val_dice",
)

swi_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds,
    metrics=swi_callback_metric,
    num_classes=num_classes,
    interval=2,
    overlap=0.5,
    roi_size=input_shape[:3],
    sw_batch_size=4,
    mode="gaussian",
    save_path="brats.model.weights.h5",
)

"""
## Training

We train for only a few epochs here; increase the number of epochs for better optimization.
"""

history = model.fit(train_ds, epochs=epochs, callbacks=[swi_callback])

# Comment this line to keep the logs visible
clear_output()

"""
Let's take a quick look at how our model performed during training. We will first print the available metrics recorded in the training history, save them to a CSV file for future reference, and then visualize them to better understand the model's learning progress over epochs.
"""

def plot_training_history(history_df):
    metrics = history_df.columns
    n_metrics = len(metrics)

    n_rows = 2
    n_cols = (n_metrics + 1) // 2  # ceiling division for columns

    plt.figure(figsize=(5 * n_cols, 5 * n_rows))

    for idx, metric in enumerate(metrics):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.plot(history_df[metric], label=metric, marker="o")
        plt.title(metric)
        plt.xlabel("Epoch")
        plt.ylabel("Value")
        plt.grid(True)
        plt.legend()

    plt.tight_layout()
    plt.show()


print(model.history.history.keys())
his_csv = pd.DataFrame(model.history.history)
his_csv.to_csv("brats.history.csv")
plot_training_history(his_csv)

"""
## Evaluation

In this [Kaggle notebook](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-multi-gpu/) (version 5), we trained the model on the entire dataset for approximately `30` epochs. The resulting weights will be used for further evaluation. Note that the validation set used here and in the Kaggle notebook is the same: `training_shard_36.tfrec`, which contains `8` samples.
"""

model_weight = kagglehub.model_download(
    "ipythonx/bratsmodel/keras/default", path="brats.model.weights.h5"
)
print("\nPath to model files:", model_weight)

model.load_weights(model_weight)

"""
In this section, we perform sliding window inference on the validation dataset and compute Dice scores for overall segmentation quality as well as for the specific tumor sub-regions:

- Tumor Core (TC)
- Whole Tumor (WT)
- Enhancing Tumor (ET)
"""

swi = SlidingWindowInference(
    model,
    num_classes=num_classes,
    roi_size=input_shape[:3],
    sw_batch_size=4,
    overlap=0.5,
    mode="gaussian",
)

dice = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    num_classes=num_classes,
    name="dice",
)
dice_tc = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[0],
    num_classes=num_classes,
    name="dice_tc",
)
dice_wt = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[1],
    num_classes=num_classes,
    name="dice_wt",
)
dice_et = BinaryDiceMetric(
    from_logits=True,
    ignore_empty=True,
    target_class_ids=[2],
    num_classes=num_classes,
    name="dice_et",
)

"""
Because the validation volumes are large and vary in size, we iterate over the validation dataloader one sample at a time. The sliding window inference splits each volume into patches, runs the model on them, and stitches the patch predictions back into a full-volume prediction.
"""

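"""
To get a feel for the amount of work involved, the sketch below (illustrative only, assuming the usual sliding-window tiling scheme) estimates how many `96³` windows are needed to cover a typical `240 x 240 x 155` BraTS volume with `overlap=0.5`:

```python
roi, overlap = 96, 0.5
stride = int(roi * (1 - overlap))  # 48 voxels between window origins
windows_per_axis = [
    (dim - roi + stride - 1) // stride + 1 for dim in (240, 240, 155)
]
print(windows_per_axis)  # [4, 4, 3] -> 48 windows in total
```

`sw_batch_size` controls how many of these windows are passed through the model at a time.
"""
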
dice.reset_state()
dice_tc.reset_state()
dice_wt.reset_state()
dice_et.reset_state()

for sample in val_ds:
    x, y = sample
    output = swi(x)
    dice.update_state(y, output)
    dice_tc.update_state(y, output)
    dice_wt.update_state(y, output)
    dice_et.update_state(y, output)

dice_score = float(ops.convert_to_numpy(dice.result()))
dice_score_tc = float(ops.convert_to_numpy(dice_tc.result()))
dice_score_wt = float(ops.convert_to_numpy(dice_wt.result()))
dice_score_et = float(ops.convert_to_numpy(dice_et.result()))

# Comment this line to keep the logs visible
clear_output()

print(f"Dice Score: {dice_score:.4f}")
print(f"Dice Score on tumor core (TC): {dice_score_tc:.4f}")
print(f"Dice Score on whole tumor (WT): {dice_score_wt:.4f}")
print(f"Dice Score on enhancing tumor (ET): {dice_score_et:.4f}")

"""
## Analyse and Visualize

Let's analyse the model predictions and visualize them. First, we will implement the test transformation pipeline, which is the same as the validation transformation.
"""


def test_transformation(sample):
    return val_transformation(sample)


"""
Let's load the `tfrecord` file and check its properties.
"""

index = 0
dataset = tf.data.TFRecordDataset(val_datalist[index])
dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(rearrange_shape, num_parallel_calls=tf.data.AUTOTUNE)

sample = next(iter(dataset))
orig_image = sample["image"]
orig_label = sample["label"]
print(orig_image.shape, orig_label.shape, np.unique(orig_label))

"""
Run the transformation to prepare the inputs.
"""

pre_image, pre_label = test_transformation(sample)
print(pre_image.shape, pre_label.shape)

"""
Pass the preprocessed sample to the inference object, ensuring that a batch axis is added to the input beforehand.
"""

y_pred = swi(pre_image[None, ...])

# Comment this line to keep the logs visible
clear_output()

print(y_pred.shape)

"""
After running inference, we remove the batch dimension and apply a `sigmoid` activation to obtain class probabilities. We then threshold the probabilities at `0.5` to generate the final binary segmentation map.
"""

y_pred_logits = y_pred.squeeze(axis=0)
y_pred_prob = ops.convert_to_numpy(ops.sigmoid(y_pred_logits))
segment = (y_pred_prob > 0.5).astype(int)
print(segment.shape, np.unique(segment))


"""
We compare the ground truth (`pre_label`) and the predicted segmentation (`segment`) for each tumor sub-region. Each sub-plot shows a specific channel corresponding to a tumor type: TC, WT, and ET. Here we visualize the `80th` axial slice across the three channels.
"""

label_map = {0: "TC", 1: "WT", 2: "ET"}

plt.figure(figsize=(16, 4))
for i in range(pre_label.shape[-1]):
    plt.subplot(1, 3, i + 1)
    plt.title(f"label channel {label_map[i]}")
    plt.imshow(pre_label[80, :, :, i])
plt.show()

plt.figure(figsize=(16, 4))
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.title(f"pred channel {label_map[i]}")
    plt.imshow(segment[80, :, :, i])
plt.show()


"""
The predicted output is a multi-channel binary map, where each channel corresponds to a specific tumor region. To visualize it against the original ground truth, we convert it into a single-channel label map. Here we assign:

- Label 1 for Tumor Core (TC)
- Label 2 for Whole Tumor (WT)
- Label 4 for Enhancing Tumor (ET)

The label values are chosen to match typical conventions used in medical segmentation benchmarks like BraTS.
"""

# Assign WT first, then overwrite with TC and ET so the more specific regions take precedence.
prediction = np.zeros(
    (segment.shape[0], segment.shape[1], segment.shape[2]), dtype="float32"
)
prediction[segment[..., 1] == 1] = 2
prediction[segment[..., 0] == 1] = 1
prediction[segment[..., 2] == 1] = 4

print("label ", orig_label.shape, np.unique(orig_label))
print("predicted ", prediction.shape, np.unique(prediction))


"""
Let's begin by examining the original input slices from the MRI scan. The input contains four channels corresponding to different MRI modalities:

- FLAIR
- T1
- T1CE (T1 with contrast enhancement)
- T2

We display the same slice number across all modalities for comparison.
"""

slice_map = {0: "flair", 1: "t1", 2: "t1ce", 3: "t2"}
slice_num = 75

plt.figure(figsize=(16, 4))
for i in range(orig_image.shape[-1]):
    plt.subplot(1, 4, i + 1)
    plt.title(f"Original channel: {slice_map[i]}")
    plt.imshow(orig_image[slice_num, :, :, i], cmap="gray")

plt.tight_layout()
plt.show()

"""
Next, we compare this input with the ground truth label and the predicted segmentation on the same slice. This provides visual insight into how well the model has localized and segmented the tumor regions.
"""

num_channels = orig_image.shape[-1]
plt.figure("image", (15, 15))

# plotting image, label and prediction
plt.subplot(3, num_channels, num_channels + 1)
plt.title("image")
plt.imshow(orig_image[slice_num, :, :, 0], cmap="gray")

plt.subplot(3, num_channels, num_channels + 2)
plt.title("label")
plt.imshow(orig_label[slice_num, :, :])

plt.subplot(3, num_channels, num_channels + 3)
plt.title("prediction")
plt.imshow(prediction[slice_num, :, :])

plt.tight_layout()
plt.show()

"""
Finally, create a clean GIF visualizer showing the input image, ground-truth label, and model prediction.
"""
# The input volume contains large black margins, so we crop
# the foreground region of interest (ROI).
crop_foreground = CropForeground(
    keys=("image", "label", "prediction"), source_key="image"
)

data = {
    "image": orig_image,
    "label": orig_label[..., None],
    "prediction": prediction[..., None],
}
results = crop_foreground(data)
crop_orig_image = results["image"]
crop_orig_label = results["label"]
crop_prediction = results["prediction"]

"""
Prepare a visualization-friendly prediction map by remapping label values to a compact index range.
"""

viz_pred = np.zeros_like(crop_prediction, dtype="uint8")
viz_pred[crop_prediction == 1] = 1
viz_pred[crop_prediction == 2] = 2
viz_pred[crop_prediction == 4] = 3

# Colormap for background, tumor core, edema, and enhancing regions
cmap = ListedColormap(
    [
        "#000000",  # background
        "#E57373",  # muted red
        "#64B5F6",  # muted blue
        "#81C784",  # muted green
    ]
)

# Create side-by-side views for input, label, and prediction
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
ax_img, ax_lbl, ax_pred = axes

img_im = ax_img.imshow(crop_orig_image[0, :, :, 0], cmap="gray")
lbl_im = ax_lbl.imshow(
    crop_orig_label[0], vmin=0, vmax=3, cmap=cmap, interpolation="nearest"
)
pred_im = ax_pred.imshow(
    viz_pred[0], vmin=0, vmax=3, cmap=cmap, interpolation="nearest"
)

# Tight layout for a compact GIF
plt.subplots_adjust(left=0.01, right=0.99, bottom=0.02, top=0.8, wspace=0.01)

for ax, t in zip(axes, ["FLAIR", "Label", "Prediction"]):
    ax.set_title(t, fontsize=19, pad=10)
    ax.axis("off")
    ax.set_adjustable("box")


def update(i):
    img_im.set_data(crop_orig_image[i, :, :, 0])
    lbl_im.set_data(crop_orig_label[i])
    pred_im.set_data(viz_pred[i])
    fig.suptitle(f"Slice {i}", fontsize=14)
    return img_im, lbl_im, pred_im


ani = animation.FuncAnimation(
    fig, update, frames=crop_orig_image.shape[0], interval=120
)
ani.save(
    "segmentation_slices.gif",
    writer="pillow",
    dpi=100,
)
plt.close(fig)

"""
When you open the saved GIF, you should see a visualization similar to this:

![Animation of the brain tumor segmentation results](https://i.imgur.com/CbaQGf2.gif)
"""

"""
## Additional Resources

- [BraTS Segmentation on Multi-GPU](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-multi-gpu)
- [BraTS Segmentation on TPU-VM](https://www.kaggle.com/code/ipythonx/3d-brats-segmentation-in-keras-tpu-vm)
- [BraTS .nii to TFRecord](https://www.kaggle.com/code/ipythonx/brats-nii-to-tfrecord)
- [COVID-19 Segmentation](https://www.kaggle.com/code/ipythonx/medicai-covid-19-3d-image-segmentation)
- [3D Multi-organ Segmentation](https://www.kaggle.com/code/ipythonx/medicai-3d-btcv-segmentation-in-keras)
- [Spleen 3D Segmentation](https://www.kaggle.com/code/ipythonx/medicai-spleen-3d-segmentation-in-keras)
- [3D Medical Image Transformation](https://www.kaggle.com/code/ipythonx/medicai-3d-medical-image-transformation)
"""