Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
davidADSP
GitHub Repository: davidADSP/Generative_Deep_Learning_2nd_Edition
Path: blob/main/notebooks/utils.py
960 views
1
import matplotlib.pyplot as plt
2
3
4
def sample_batch(dataset):
    """
    Return one batch from ``dataset`` converted via ``.numpy()``.

    Takes a single element with ``dataset.take(1).get_single_element()``.
    When that element is a tuple, only its first item is kept; what the
    remaining items hold is not visible here (presumably labels or other
    targets -- confirm against the callers).

    Args:
        dataset: Object exposing ``take(1).get_single_element()``
            (presumably a ``tf.data.Dataset`` -- confirm against the callers).

    Returns:
        Whatever ``.numpy()`` returns for the selected element.
    """
    batch = dataset.take(1).get_single_element()
    if isinstance(batch, tuple):
        # Tuple-structured element: keep only the first component.
        batch = batch[0]
    return batch.numpy()
9
10
11
def display(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """
    Display the first ``n`` images of ``images`` side by side in one row.

    Before plotting, the values are rescaled: if the maximum exceeds 1.0 the
    whole array is divided by 255.0; otherwise, if the minimum is below 0.0,
    it is mapped through ``(images + 1.0) / 2.0``. Arrays that meet neither
    condition are plotted unchanged.

    Args:
        images: Indexable array supporting ``max()``, ``min()``, ``len()``
            and arithmetic, whose items provide ``astype`` (e.g. a NumPy
            array with the image index on the first axis).
        n: Maximum number of images to show. If ``images`` holds fewer than
            ``n`` items, only the available ones are drawn.
        size: Figure size passed to ``plt.figure(figsize=...)``.
        cmap: Colormap passed to ``plt.imshow``.
        as_type: dtype each image is cast to before plotting.
        save_to: Optional target path; when truthy, the figure is written
            there with ``plt.savefig`` before it is shown.
    """
    if images.max() > 1.0:
        images = images / 255.0
    elif images.min() < 0.0:
        images = (images + 1.0) / 2.0

    # Cap the count so a batch smaller than n does not raise IndexError
    # halfway through building the figure.
    n_shown = min(n, len(images))

    plt.figure(figsize=size)
    for i in range(n_shown):
        plt.subplot(1, n_shown, i + 1)
        plt.imshow(images[i].astype(as_type), cmap=cmap)
        plt.axis("off")

    if save_to:
        # Save before show(): some backends leave the figure empty afterwards.
        plt.savefig(save_to)
        print(f"\nSaved to {save_to}")

    plt.show()
33
34