Trax: Ungraded Lecture Notebook
In this notebook you'll get to know about the Trax framework and learn about some of its basic building blocks.
Background
Why Trax and not TensorFlow or PyTorch?
TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.
Trax is much more concise. It runs on a TensorFlow backend but allows you to train models with one-line commands. Trax also runs end to end, allowing you to get data, build a model, and train it, all with a few terse statements. This means you can focus on learning instead of spending hours on the idiosyncrasies of a big framework's implementation.
Why not Keras then?
Keras has been part of TensorFlow itself since version 2.0. Trax, in turn, is well suited to implementing new state-of-the-art algorithms such as Transformers, Reformers, and BERT, because it is actively maintained by the Google Brain team for advanced deep learning tasks. It also runs smoothly on CPUs, GPUs, and TPUs with comparatively few code modifications.
How to Code in Trax
Building models in Trax relies on two key concepts: layers and combinators. Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.
Trax, JAX, TensorFlow and Tensor2Tensor
You already know that Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. You can view JAX as an enhanced and optimized version of numpy.
Watch out for assignments that contain the line import trax.fastmath.numpy as np. If you see this line, remember that when calling np you are really calling Trax's version of numpy that is compatible with JAX.
As a result, where you used to encounter the type numpy.ndarray, you will now find the type jax.interpreters.xla.DeviceArray (newer JAX releases report this type as jax.Array instead).
Tensor2Tensor is another name you might have heard. It started as an end-to-end solution, much like Trax, but it grew unwieldy and complicated. You can view Trax as its new and improved successor, one that is faster and simpler to use.
Resources
Installing Trax
Trax depends on JAX and related libraries, which are not yet supported on Windows but work well on Ubuntu and macOS. If you are working on Windows, we suggest installing Trax on WSL2.
Officially maintained documentation: trax-ml (not to be confused with this TraX)
Imports
trax 1.3.1
Layers
Layers are the core building blocks in Trax; as mentioned in the lectures, they are the base classes.
They take inputs, compute functions/custom calculations and return outputs.
You can also inspect layer properties. Let me show you some examples.
Relu Layer
First I'll show you how to build a relu activation function as a layer. A layer like this is one of the simplest types. Notice there is no object initialization so it works just like a math function.
Note: Activation functions are also layers in Trax, which might look odd if you have been using other frameworks for a long time.
Concatenate Layer
Now I'll show you how to build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.
Layers are Configurable
You can change the default settings of layers. For example, you can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter n_items.
Note: At any point, if you want to know more about a function, look up the documentation or use the help function.
Layers can have Weights
Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.
For example, the LayerNorm layer normalizes its input data and then scales and shifts the result using weights and biases. During initialization you pass the shape and data type of the inputs, so the layer can initialize compatible arrays of weights and biases.
Custom Layers
This is where things start getting more interesting! You can create your own custom layers too and define custom functions for computations by using tl.Fn. Let me show you how.
Combinators
You can combine layers to build more complex layers. Trax provides a set of objects named combinator layers to make this happen. Combinators are themselves layers, so they behave like any other layer and can be nested inside other combinators.
Serial Combinator
This is the most common combinator and the easiest to use. For example, you could build a simple neural network by combining layers into a single layer using the Serial combinator. This new layer then acts just like a single layer, so you can inspect its inputs, outputs, and weights, or even combine it into another layer! Combinators can then be used as trainable models. Try adding more layers.
Note: As you may have guessed, if there is a Serial combinator, there must be a Parallel combinator as well. Explore the combinators and other layers in the Trax documentation, and look at the repo to understand how these layers are written.
JAX
Just remember to look out for which numpy you are using: the regular ol' numpy or Trax's JAX-compatible numpy. Both tend to use the alias np, so watch those import blocks.
Note: Some things that regular numpy can do are still not possible in fastmath.numpy, so in the assignments you will see us switch between the two to get our work done.
Summary
Trax is a concise framework, built on TensorFlow, for end-to-end machine learning. The key building blocks are layers and combinators. This notebook is just a taste, but it sets you up with some key intuitions to take forward into the rest of the course and the assignments, where you will build end-to-end models.