Path: blob/master/notebooks/book1/14/layer_norm_torch.ipynb
Kernel: Python 3
A JAX implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/14/layer_norm_jax.ipynb
In [1]:
In [2]:
Out[2]:
batch norm
[[-1. -1. -1.]
 [ 1.  1.  1.]]
layer norm
[[-1.22474487  0.          1.22474487]
 [-1.22474487  0.          1.22474487]]
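The code cells of this notebook did not survive the export. The following is a minimal NumPy sketch that should reproduce Out[2], assuming a hypothetical input `x = [[1, 2, 3], [4, 5, 6]]` (chosen only because it is consistent with the printed values, not taken from the original) and no learnable scale or shift. The helper names `batch_norm` and `layer_norm` are illustrative. The only difference between the two is the axis over which the mean and variance are taken: batch norm uses axis 0 (across examples, per feature), layer norm uses the last axis (across features, per example).

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    """Normalize each feature (column) with statistics computed over the batch (axis 0)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)  # biased variance (ddof=0)
    return (x - mean) / np.sqrt(var + eps)


def layer_norm(x, eps=1e-5):
    """Normalize each example (row) with statistics computed over its features (last axis)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance (ddof=0)
    return (x - mean) / np.sqrt(var + eps)


# Hypothetical input: 2 examples (rows), 3 features (columns).
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

print("batch norm")
print(batch_norm(x))
print("layer norm")
print(layer_norm(x))
```

With the default `eps=1e-5` the results agree with Out[2] to about four decimal places; passing `eps=0.0` should match it digit for digit for this input, since each column has nonzero variance.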
In [4]:
Out[4]:
batch norm
tensor([[-1.0000, -1.0000, -1.0000],
        [ 1.0000,  1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward>)
layer norm
tensor([[-1.2247e+00,  0.0000e+00,  1.2247e+00],
        [-1.2247e+00,  1.1921e-07,  1.2247e+00]],
       grad_fn=<NativeLayerNormBackward>)
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
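A PyTorch counterpart can be sketched with the built-in `torch.nn.BatchNorm1d` and `torch.nn.LayerNorm` modules. This is a sketch under assumptions, not the notebook's lost code: it reuses the same hypothetical input `[[1, 2, 3], [4, 5, 6]]` and relies on the modules' defaults (`eps=1e-5`, affine scale 1 and shift 0 at initialization, training mode so `BatchNorm1d` uses the current batch statistics). Building the tensor from a Python list avoids the copy-construct `UserWarning` shown above; if the input is already a tensor, `x.clone().detach()` is the form the warning recommends. The `1.1921e-07` in Out[4] where an exact 0 is expected is float32 rounding error (it is about the float32 machine epsilon), since PyTorch computes in float32 by default.

```python
import torch
import torch.nn as nn

# Hypothetical input (2 examples x 3 features), built from a Python list.
# If the data were already a tensor, use existing.clone().detach() instead of
# torch.tensor(existing) to avoid the copy-construct UserWarning.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# BatchNorm1d: normalizes each of the 3 features across the batch dimension (dim 0).
bn = nn.BatchNorm1d(num_features=3)
bn.train()  # modules start in training mode; stated explicitly for clarity

# LayerNorm: normalizes each example across its last dimension (the 3 features).
ln = nn.LayerNorm(normalized_shape=3)

y_bn = bn(x)
y_ln = ln(x)

print("batch norm")
print(y_bn)
print("layer norm")
print(y_ln)
```

Note that calling `bn(x)` in training mode also updates the module's running mean and variance, which are what it uses after `bn.eval()`; `LayerNorm` keeps no such running statistics because it never looks across the batch.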