GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/14/layer_norm_jax.ipynb
Kernel: Python 3
import numpy as np
import jax
import jax.numpy as jnp

try:
    from flax import linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn
# batch size 2, feature size 3
np.random.seed(42)
X = np.random.normal(size=(2, 3))

print("batch norm")
# batch norm: standardize each feature (column) across the batch dimension
mu_batch = np.mean(X, axis=0)
sigma_batch = np.std(X, axis=0)
XBN = (X - mu_batch) / sigma_batch
print(XBN)

print("layer norm")
# layer norm: standardize each example (row) across its features
mu_layer = np.expand_dims(np.mean(X, axis=1), axis=1)
sigma_layer = np.expand_dims(np.std(X, axis=1), axis=1)
XLN = (X - mu_layer) / sigma_layer
print(XLN)
batch norm
[[-1.  1.  1.]
 [ 1. -1. -1.]]
layer norm
[[ 0.47376014 -1.39085732  0.91709718]
 [ 1.41421356 -0.70711669 -0.70709687]]
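For reference, the per-example normalization above can be written as a single JAX function. This is only an illustrative sketch (the helper name layer_norm and the eps value are ours, not part of the original notebook); it mirrors what nn.LayerNorm computes below, before its learnable scale and bias are applied.

def layer_norm(x, eps=1e-6):
    # standardize each row (example) over its features; eps guards against division by zero
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

print(layer_norm(jnp.float32(X)))  # should closely match XLN above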
# convert to float32 and check that the Flax implementations agree with the NumPy versions
X = jnp.float32(X)
rng = jax.random.PRNGKey(42)
bn_rng, ln_rng = jax.random.split(rng)

print("batch norm")
bn = nn.BatchNorm(use_running_average=False, epsilon=1e-6)
bn_params = bn.init(bn_rng, X)
# mutable=["batch_stats"] lets apply() update the running mean/variance
XBN_t, _ = bn.apply(bn_params, X, mutable=["batch_stats"])
print(XBN_t)
assert np.allclose(np.array(XBN_t), XBN, atol=1e-3)

print("layer norm")
ln = nn.LayerNorm()
ln_params = ln.init(ln_rng, X)
XLN_t = ln.apply(ln_params, X)
print(XLN_t)
assert np.allclose(np.array(XLN_t), XLN, atol=1e-3)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
batch norm
[[-0.99999815  0.99978346  0.99999744]
 [ 0.99999815 -0.9997831  -0.9999975 ]]
layer norm
[[ 0.473758   -1.3908514   0.9170933 ]
 [ 1.4142125  -0.70711625 -0.7070964 ]]
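At evaluation time, BatchNorm is normally switched to use_running_average=True so that the running statistics collected during training are used instead of the current batch statistics. Below is a minimal sketch of that pattern, assuming the bn and bn_params variables from the cell above (the names new_state, eval_vars, and bn_eval are illustrative). The output will not match XBN exactly, since the running averages are only partially updated after a single batch.

# one training-style application updates the running statistics
_, new_state = bn.apply(bn_params, X, mutable=["batch_stats"])

# evaluation: reuse the learned scale/bias together with the updated running statistics
eval_vars = {"params": bn_params["params"], "batch_stats": new_state["batch_stats"]}
bn_eval = nn.BatchNorm(use_running_average=True, epsilon=1e-6)
XBN_eval = bn_eval.apply(eval_vars, X)
print(XBN_eval)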