Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/14/conv2d_jax.ipynb
1192 views
Kernel: Python 3 (ipykernel)

Open In Colab

Foundations of Convolutional neural nets

Based on sec 6.2 of http://d2l.ai/chapter_convolutional-neural-networks/conv-layer.html

import jax import jax.numpy as jnp try: import flax.linen as nn except ModuleNotFoundError: %pip install -qq flax import flax.linen as nn from typing import Tuple !mkdir figures # for saving plots import warnings warnings.filterwarnings("ignore") key = jax.random.PRNGKey(1)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Cross correlation

# Cross correlation
def corr2d(X: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate the 2D array X with the 2D kernel K.

    The kernel is slid over every position where it fits entirely inside X
    (no padding, unit stride). Each output entry is the sum of the
    elementwise product of K with the window of X it covers.

    Args:
        X: 2D input array.
        K: 2D kernel.

    Returns:
        Array of shape (X.shape[0] - K.shape[0] + 1,
        X.shape[1] - K.shape[1] + 1).
    """
    kernel_h, kernel_w = K.shape
    out_h = X.shape[0] - kernel_h + 1
    out_w = X.shape[1] - kernel_w + 1
    out = jnp.zeros((out_h, out_w))
    for row in range(out_h):
        for col in range(out_w):
            window = X[row : row + kernel_h, col : col + kernel_w]
            # JAX arrays are immutable, so each write yields a new array.
            out = out.at[row, col].set(jnp.sum(window * K))
    return out


X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])
print(corr2d(X, K))
[[19. 25.] [37. 43.]]

Edge detection

We make a small image X of 1s, with a vertical stripe (of width 4) of 0s in the middle.

# Start from an all-ones 6x8 image and zero out columns 2 through 5,
# leaving a vertical stripe of zeros in the middle.
X = jnp.ones((6, 8)).at[:, 2:6].set(0)
X
DeviceArray([[1., 1., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0., 1., 1.]], dtype=float32)

Now we apply a vertical edge detector. It fires on the 1-0 and 0-1 boundaries.

# A 1x2 difference kernel: each output entry is a pixel minus its right-hand
# neighbour, so it is non-zero only where adjacent columns differ.
K = jnp.array([1.0, -1.0])[None, :]
Y = corr2d(X, K)
print(Y)
[[ 0. 1. 0. 0. 0. -1. 0.] [ 0. 1. 0. 0. 0. -1. 0.] [ 0. 1. 0. 0. 0. -1. 0.] [ 0. 1. 0. 0. 0. -1. 0.] [ 0. 1. 0. 0. 0. -1. 0.] [ 0. 1. 0. 0. 0. -1. 0.]]

The same kernel fails to detect horizontal edges: applied to the transposed image, whose edges are horizontal, it outputs all zeros.

# Transposing the image turns its vertical stripe into a horizontal one;
# the same kernel is then applied to it.
corr2d(jnp.transpose(X), K)
DeviceArray([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], dtype=float32)

Convolution as matrix multiplication

K = jnp.array([[1, 2], [3, 4]])
print(K)


def kernel2matrix(K: jnp.ndarray, input_shape: Tuple[int, int] = (3, 3)) -> jnp.ndarray:
    """Build the matrix W that applies cross-correlation with K as a matmul.

    For a 2D input X of shape `input_shape`, `W @ X.reshape(-1)` reshaped to
    the output shape equals the cross-correlation of X with K (no padding,
    unit stride).

    Args:
        K: 2D kernel of shape (h, w).
        input_shape: (height, width) of the input the matrix will be applied
            to. Defaults to (3, 3), which together with a 2x2 kernel
            reproduces the original fixed-size 4x9 construction.

    Returns:
        Array of shape (out_h * out_w, in_h * in_w), where
        out_h = in_h - h + 1 and out_w = in_w - w + 1.
    """
    h, w = K.shape
    in_h, in_w = input_shape
    out_h, out_w = in_h - h + 1, in_w - w + 1

    # Lay the kernel rows out as they appear in the row-major flattened
    # input: consecutive kernel rows sit `in_w` entries apart.
    span = (h - 1) * in_w + w
    k = jnp.zeros(span)
    for r in range(h):
        k = k.at[r * in_w : r * in_w + w].set(K[r, :])

    # One row of W per output pixel; the strip starts at the flattened index
    # of the window's top-left input pixel.
    W = jnp.zeros((out_h * out_w, in_h * in_w))
    for i in range(out_h):
        for j in range(out_w):
            start = i * in_w + j
            W = W.at[i * out_w + j, start : start + span].set(k)
    return W


W = kernel2matrix(K)
print(W)
[[1 2] [3 4]] [[1. 2. 0. 3. 4. 0. 0. 0. 0.] [0. 1. 2. 0. 3. 4. 0. 0. 0.] [0. 0. 0. 1. 2. 0. 3. 4. 0.] [0. 0. 0. 0. 1. 2. 0. 3. 4.]]
# Check that multiplying the flattened input by W reproduces corr2d.
X = jnp.arange(9.0).reshape(3, 3)
Y = corr2d(X, K)
print(Y)
Y2 = (W @ X.ravel()).reshape(2, 2)
assert jnp.allclose(Y, Y2)
[[27. 37.] [57. 67.]]

Optimizing the kernel parameters

Let's learn a kernel to match the output of our manual edge detector.

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here.
conv2d = nn.Conv(1, kernel_size=(1, 2), padding=((0, 0), (0, 0)), use_bias=False)

# Defining X and Y again: the striped image and the output of the manual edge
# detector that the learned kernel should reproduce.
X = jnp.ones((6, 8))
X = X.at[:, 2:6].set(0)
K = jnp.array([[1.0, -1.0]])
Y = corr2d(X, K)
print(Y.shape)

# The convolutional layer uses four-dimensional input and output. The arrays
# are first shaped as (example, channel, height, width) and then transposed to
# (example, height, width, channel); the batch size (number of examples in the
# batch) and the number of channels are both 1.
X = jnp.transpose(X.reshape((1, 1, 6, 8)), (0, 2, 3, 1))
Y = jnp.transpose(Y.reshape((1, 1, 6, 7)), (0, 2, 3, 1))

params = conv2d.init(key, X)


@jax.jit
def step(params, X, Y):
    """Return the summed squared error between the layer's output on X and Y."""
    return jnp.sum((conv2d.apply(params, X) - Y) ** 2)


# Build the value-and-gradient function once instead of on every iteration.
loss_and_grad = jax.value_and_grad(step)

for i in range(10):
    l, grads = loss_and_grad(params, X, Y)
    # Plain gradient-descent update with a fixed step size of 3e-2.
    # jax.tree_util.tree_map is used because the top-level jax.tree_map alias
    # was deprecated and later removed from JAX.
    params = jax.tree_util.tree_map(lambda p, g: p - 3e-2 * g, params, grads)
    if (i + 1) % 2 == 0:
        print(f"batch {i + 1}, loss {l:.3f}")

print(params["params"]["kernel"].reshape((1, 2)))
(6, 7) batch 2, loss 19.368 batch 4, loss 5.172 batch 6, loss 1.655 batch 8, loss 0.600 batch 10, loss 0.233 [[ 1.0312853 -0.93357116]]

Multiple input channels

def corr2d(X: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
    """Return the 2D cross-correlation of X with the kernel K.

    The kernel visits every position where it lies fully inside X (no
    padding, unit stride); each output entry is the sum of the elementwise
    product between K and the window of X beneath it.

    Args:
        X: 2D input array.
        K: 2D kernel.

    Returns:
        Array of shape (X.shape[0] - K.shape[0] + 1,
        X.shape[1] - K.shape[1] + 1).
    """
    k_rows, k_cols = K.shape
    n_rows = X.shape[0] - k_rows + 1
    n_cols = X.shape[1] - k_cols + 1
    result = jnp.zeros((n_rows, n_cols))
    for r in range(n_rows):
        for c in range(n_cols):
            patch = X[r : r + k_rows, c : c + k_cols]
            # Functional update: .at[...].set returns a fresh array.
            result = result.at[r, c].set(jnp.sum(patch * K))
    return result
def corr2d_multi_in(X: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    X and K are iterated together along their 0th (channel) dimension; each
    channel of X is cross-correlated with the matching slice of K via
    `corr2d`, and the per-channel results are added together.
    """
    total = 0
    for x, k in zip(X, K):
        total = total + corr2d(x, k)
    return total


X = jnp.array(
    [
        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
    ]
)
K = jnp.array([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
print(X.shape)  # 2 channels, each 3x3
print(K.shape)  # 2 sets of 2x2 filters
out = corr2d_multi_in(X, K)
print(out.shape)
print(out)
(2, 3, 3) (2, 2, 2) (2, 2) [[ 56. 72.] [104. 120.]]

Multiple output channels

def corr2d_multi_in_out(X: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate X with each kernel in K and stack the results.

    Each slice of K along its 0th dimension is passed, together with the
    full input X, to `corr2d_multi_in`; the resulting maps are stacked along
    a new leading axis.
    """
    per_output = [corr2d_multi_in(X, k) for k in K]
    return jnp.stack(per_output, axis=0)


# Three output channels, built by offsetting the existing kernel by 0, 1, 2.
K = jnp.stack([K, K + 1, K + 2], axis=0)
print(K.shape)
out = corr2d_multi_in_out(X, K)
print(out.shape)
(3, 2, 2, 2) (3, 2, 2)
# 1x1 conv is same as multiplying each feature column at each pixel
# by a fully connected matrix
def corr2d_multi_in_out_1x1(X: jnp.ndarray, K: jnp.ndarray) -> jnp.ndarray:
    """Apply a 1x1 multi-channel cross-correlation as one matrix product.

    Args:
        X: Input of shape (c_i, h, w).
        K: Kernel whose leading dimension is c_o and which holds exactly
            c_o * c_i entries (e.g. shape (c_o, c_i, 1, 1)).

    Returns:
        Array of shape (c_o, h, w).
    """
    c_i, h, w = X.shape
    c_o = K.shape[0]
    # Flatten the spatial dimensions so every pixel becomes one column.
    X = X.reshape((c_i, h * w))
    K = K.reshape((c_o, c_i))
    Y = jnp.matmul(K, X)  # Matrix multiplication in the fully-connected layer
    return Y.reshape((c_o, h, w))


# Split the key so X and K are drawn from independent random streams; reusing
# a single key for both draws would make them statistically dependent.
key_x, key_k = jax.random.split(key)
X = jax.random.truncated_normal(key_x, 0, 1, (3, 3, 3))  # 3 channels per pixel
K = jax.random.truncated_normal(key_k, 0, 1, (2, 3, 1, 1))  # map from 3 channels to 2
Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
print(Y2.shape)
# Compare with a floating-point tolerance rather than a hard bound on the
# summed absolute error, so the check does not depend on the particular draw.
assert jnp.allclose(Y1, Y2, atol=1e-6)
(2, 3, 3)

Pooling

def pool2d(X: jnp.ndarray, pool_size: Tuple[int, int], mode: str = "max") -> jnp.ndarray:
    """Apply 2D max or average pooling with unit stride and no padding.

    Args:
        X: 2D input array.
        pool_size: (height, width) of the pooling window.
        mode: "max" for max pooling or "avg" for average pooling.

    Returns:
        Array of shape (X.shape[0] - p_h + 1, X.shape[1] - p_w + 1) holding
        the pooled value of every window.

    Raises:
        ValueError: If `mode` is neither "max" nor "avg".
    """
    # Resolve the reduction once, up front. An unrecognised mode used to fall
    # through both branches and silently return an all-zero array.
    if mode == "max":
        reduce_window = jnp.max
    elif mode == "avg":
        reduce_window = jnp.mean
    else:
        raise ValueError(f"mode must be 'max' or 'avg', got {mode!r}")

    p_h, p_w = pool_size
    Y = jnp.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y = Y.at[i, j].set(reduce_window(X[i : i + p_h, j : j + p_w]))
    return Y
# Max-pool a 4x4 grid of consecutive integers with a 3x3 window.
X = jnp.arange(16).reshape(4, 4)
print(X)
print(X.shape)
print(pool2d(X, pool_size=(3, 3), mode="max"))
[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] (4, 4) [[10. 11.] [14. 15.]]
# Same 4x4 grid, given singleton leading and trailing axes before it is
# handed to Flax's max_pool (presumably batch and channel axes — confirm
# against the Flax pooling docs).
X = jnp.arange(16).reshape(1, 4, 4, 1)
y = nn.max_pool(X, window_shape=(3, 3), strides=(1, 1))
print(y)
[[[[10] [11]] [[14] [15]]]]