# Path: blob/main/notebooks/utils.py
# 960 views
import matplotlib.pyplot as plt


def sample_batch(dataset):
    """Return the first element of ``dataset`` converted with ``.numpy()``.

    The element is fetched with ``dataset.take(1).get_single_element()``.
    If that element is a tuple, only its first item is kept. Presumably the
    tuple is ``(inputs, targets)``; confirm against the dataset pipeline.

    Args:
        dataset: Object exposing ``take`` and ``get_single_element``. This
            looks like a ``tf.data.Dataset``; TODO confirm.

    Returns:
        The (first item of the) element, converted via its ``.numpy()`` method.
    """
    batch = dataset.take(1).get_single_element()
    if isinstance(batch, tuple):
        batch = batch[0]
    return batch.numpy()


def display(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """Plot the first ``n`` images of ``images`` side by side in a single row.

    Before plotting, pixel values are rescaled towards ``[0, 1]``:

    - If the maximum exceeds 1, the array is divided by 255.
    - Otherwise, if the minimum is negative, it is mapped from ``[-1, 1]``
      to ``[0, 1]``.

    Args:
        images: Indexable array supporting ``.max()``, ``.min()``, ``len()``
            and per-item ``.astype``. This is assumed to be a NumPy array of
            shape ``(batch, height, width[, channels])``; TODO confirm.
        n: Maximum number of images to show. It is clamped to
            ``len(images)``, so a short batch does not raise ``IndexError``.
        size: Figure size forwarded to ``plt.figure(figsize=...)``.
        cmap: Colormap forwarded to ``plt.imshow``.
        as_type: dtype each image is cast to before ``plt.imshow``.
        save_to: Optional path. If truthy, the figure is saved there and the
            path is printed before the figure is shown.
    """
    if images.max() > 1.0:
        images = images / 255.0
    elif images.min() < 0.0:
        images = (images + 1.0) / 2.0

    # Clamp before sizing the subplot grid. Asking for more images than the
    # array holds would otherwise fail on images[i].
    n = min(n, len(images))

    plt.figure(figsize=size)
    for i in range(n):
        plt.subplot(1, n, i + 1)
        plt.imshow(images[i].astype(as_type), cmap=cmap)
        plt.axis("off")

    if save_to:
        plt.savefig(save_to)
        print(f"\nSaved to {save_to}")

    plt.show()