Introduction to Elegy
This is slightly modified from https://poets-ai.github.io/elegy/getting-started/high-level-api/ and https://poets-ai.github.io/elegy/getting-started/low-level-api/
In this tutorial we will explore the basic features of Elegy. If you are a Keras user you should feel at home; if you are currently learning JAX, things will appear much more streamlined. To get started you will first need to install the following dependencies:
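For example, in a notebook cell (the exact package list is an assumption based on the libraries used below):

```python
# install Elegy plus the libraries used in this tutorial (package list assumed)
%pip install -q elegy datasets matplotlib
```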
Loading the Data
In this tutorial we will train a neural network on the MNIST dataset. For this we will first need to download and load the data into memory; here we will use the datasets library to do so.
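A minimal loading sketch, assuming the Hugging Face datasets package and the field names of its MNIST dataset:

```python
import numpy as np
from datasets import load_dataset

# download MNIST and expose it as NumPy arrays
dataset = load_dataset("mnist")
dataset.set_format("numpy")

x_train, y_train = dataset["train"]["image"], dataset["train"]["label"]
x_test, y_test = dataset["test"]["image"], dataset["test"]["label"]

print(x_train.shape)  # (60000, 28, 28)
```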
Defining the Architecture
The first thing we need to do is define our model's architecture inside a Module. To do this we just create a class that inherits from Module and implements a __call__ method. In this example we will create a simple 2-layer MLP:
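A sketch of such a module; the exact names (eg.Module, eg.Linear, eg.compact) follow recent Elegy releases and may differ in older versions:

```python
import jax
import jax.numpy as jnp
import elegy as eg

class MLP(eg.Module):
    @eg.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # flatten the 28x28 images and scale pixels to [0, 1]
        x = x.reshape((x.shape[0], -1)) / 255.0
        # 2-layer MLP with submodules defined inline thanks to @compact
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x
```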
This code should feel familiar to most Keras / PyTorch users. The main difference is that we are using the @compact decorator to define submodules (e.g. Linear) inline, which tends to produce much shorter and more readable code.
Creating the Model
Now that we have this module we can create an Elegy Model, which is Elegy's central API:
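A sketch of the constructor call; the specific loss and metric class names are assumptions based on Elegy's built-ins:

```python
import optax
import elegy as eg

model = eg.Model(
    module=MLP(),
    # loss and metrics do not need to match the label's structure
    loss=eg.losses.Crossentropy(),
    metrics=eg.metrics.Accuracy(),
    # any optax optimizer works here
    optimizer=optax.adam(1e-3),
)
```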
If you are a Keras user this code should look familiar; the main differences are:

- You need to pass a module with the architecture.
- loss and metrics are a bit more flexible in that they do not need to match the label's structure.
- There is no compile step; everything is done in the constructor.
- For the optimizer you can use any optax optimizer.
As in Keras, you can get a rich description of the model by calling Model.summary with a sample input:
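For example, assuming x_sample is a small batch of test images:

```python
x_sample = x_test[:64]
model.summary(x_sample)  # per-layer shapes and parameter counts
```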
Training the Model
We are now ready to pass our model some data to start training. Like in Keras, this is done via the fit method, which has more or less the same signature. Elegy supports a variety of input data sources such as TensorFlow Datasets, PyTorch DataLoaders, Elegy DataLoaders, and Python generators; check out the guide on Data Sources for more information.
The following code will train our model for 10 epochs while limiting each epoch to 200 steps and using a batch size of 64:
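A sketch of the call; the keyword names follow Elegy's Keras-style signature and may differ slightly across versions:

```python
history = model.fit(
    inputs=x_train,
    labels=y_train,
    epochs=10,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(x_test, y_test),
    shuffle=True,
    # periodically save the model; we will load it again later
    callbacks=[eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True)],
)
```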
The ModelCheckpoint callback will periodically save the model in a folder called "models/high-level"; we will use it later.
fit returns a History object which contains the losses and metrics recorded during training, which we can visualize.
Plotting learning curves
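A minimal sketch, assuming History exposes a Keras-style history dictionary mapping each metric name to its per-epoch values:

```python
import matplotlib.pyplot as plt

# plot each training metric next to its validation counterpart
for key in history.history:
    if key.startswith("val_"):
        continue  # drawn together with the training curve below
    plt.figure()
    plt.plot(history.history[key], label=key)
    val_key = "val_" + key
    if val_key in history.history:
        plt.plot(history.history[val_key], label=val_key)
    plt.xlabel("epoch")
    plt.legend()
plt.show()
```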
Generating Predictions
Having our trained model we can now take some samples from the test set and generate predictions. Let's select 9 random images and call .predict:
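For example:

```python
idx = np.random.choice(len(x_test), 9, replace=False)  # 9 random test images
x_sample = x_test[idx]

y_pred = model.predict(x_sample)  # logits of shape (9, 10)
```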
Easy, right? Finally, let's plot the results to see if they are accurate.
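One way to do this:

```python
plt.figure(figsize=(12, 12))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_sample[i], cmap="gray")
    plt.title(f"prediction: {np.argmax(y_pred[i])}")  # predicted digit
    plt.axis("off")
plt.show()
```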
Serialization
To serialize the Model you can use model.save(...); this will create a folder with some files that contain the model's code plus all parameters and states.
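For example:

```python
model.save("models/full-model")  # writes code, parameters, and states to this folder
```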
However, since we previously used the ModelCheckpoint callback, we can load the saved model using elegy.load. Let's get a new model reference containing the same weights and call its evaluate method to verify that it loaded correctly:
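A sketch of the round-trip:

```python
# load the checkpoint written by ModelCheckpoint during training
model = eg.load("models/high-level")
model.evaluate(x_test, y_test)
```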
You can also serialize your Elegy Model as a TensorFlow SavedModel, which is portable to many platforms and services; to do this you can use the saved_model method. saved_model will convert the function that creates the predictions for your Model (pred_step) from JAX to a TensorFlow version via jax2tf and then serialize it to disk.
The function saved_model accepts a sample to infer the shapes, the path where the model will be saved, and a list of batch sizes for the different signatures it accepts. Due to some current limitations in JAX it is not possible to create signatures with dynamic dimensions, so you must specify a couple that might fit your needs.
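A sketch of the call; the three arguments are those described above, but the batch_size keyword name is an assumption:

```python
model.saved_model(
    x_sample,                   # sample input used to infer shapes
    "saved-models/high-level",  # destination path
    batch_size=[1, 2, 4, 8],    # one fixed-size signature per batch size
)
```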
We can test our saved model by loading it with TensorFlow and generating a couple of predictions as we did previously:
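The cell below reconstructs the code referenced in the traceback that follows (the plotting loop is assumed to mirror the earlier one); on the Colab GPU used for the original run it hit an out-of-memory error:

```python
import tensorflow as tf

saved_model = tf.saved_model.load("saved-models/high-level")

y_pred_tf = saved_model(x_sample.astype(np.int32))

plt.figure(figsize=(12, 12))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_sample[i], cmap="gray")
    plt.title(f"prediction: {np.argmax(y_pred_tf[i])}")
    plt.axis("off")
plt.show()
```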
```text
---------------------------------------------------------------------------
ResourceExhaustedError Traceback (most recent call last)
<ipython-input-28-d29d35341090> in <module>()
3 saved_model = tf.saved_model.load("saved-models/high-level")
4
----> 5 y_pred_tf = saved_model(x_sample.astype(np.int32))
6
7 plt.figure(figsize=(12, 12))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
699
700 def _call_attribute(instance, *args, **kwargs):
--> 701 return instance.__call__(*args, **kwargs)
702
703
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 ctx.ensure_initialized()
58 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59 inputs, attrs, num_outputs)
60 except core._NotOkStatusException as e:
61 if name is not None:
ResourceExhaustedError: Out of memory while trying to allocate 16788016 bytes.
[[{{function_node __inference_<lambda>_2495}}{{node XlaDotV2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
[Op:__inference_restored_function_body_2791]
```
Distributed training
To parallelize training and inference using pmap on a multi-core TPU, you just need to add a single call after creating the model, as shown in the sketch below. For an example, try running https://github.com/probml/pyprobml/blob/master/scripts/mnist_elegy_distributed.py on a TPU VM v3-8. In Colab there will not be any speedup, since there is only 1 GPU. (I have not tried TPU mode in Colab.)
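A sketch, assuming the distributed helper exposed by recent Elegy versions:

```python
# replicate parameters and shard batches across devices via pmap
# (the distributed() helper is an assumption based on recent Elegy releases)
model = model.distributed()
```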
Low-level API
Introduction
The low-level API lets you redefine what happens during the various stages of training, evaluation and inference by implementing some methods in a custom class. Here is the list of methods you can define along with the high-level method that uses it:
| Low-level Method | High-level Method |
| :- | :- |
| pred_step | predict |
| test_step | evaluate |
| grad_step | NA |
| train_step | fit |
Check out the guides on the low-level API for more information.
In this tutorial we are going to implement a linear classifier in pure JAX by overriding pred_step, which defines the forward pass, and test_step, which defines the loss and metrics of our model.
pred_step returns a tuple with:

- y_pred: the predictions of the model
- states: an elegy.States namedtuple that contains the states for things like the network's trainable parameters, network states, metrics states, optimizer states, and the rng state.
test_step returns a tuple with:

- loss: the scalar loss used to calculate the gradient
- logs: a dictionary with the logs to be reported during training
- states: an elegy.States namedtuple that contains the states for things like the network's trainable parameters, network states, metrics states, optimizer states, and the rng state.
Since JAX is functional, you will find that the low-level API is very explicit about state management: you always get the current state as input and you return the new state as output. Let's define test_step to make things clearer:
Linear classifier
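A sketch of the implementation, following the structure described in the notes below (the argument names follow Elegy's dependency-injection convention):

```python
import jax
import jax.numpy as jnp
import elegy

class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, states: elegy.States, initializing: bool):
        # flatten the images and scale pixels to [0, 1]
        x = jnp.reshape(x, (x.shape[0], -1)) / 255.0

        if initializing:
            # create the parameters ourselves on the first call
            w = jax.random.uniform(jax.random.PRNGKey(42), shape=[x.shape[-1], 10])
            b = jnp.zeros([10])
        else:
            # afterwards they arrive through the injected state
            w, b = states.net_params

        # the model is a simple linear function
        logits = jnp.dot(x, w) + b

        # crossentropy loss and accuracy metric, both added to the logs
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        logs = {"accuracy": accuracy, "loss": loss}

        # store w and b so we receive them as input on the next run
        return loss, logs, states.update(net_params=(w, b))
```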
Notice the following:

- We define a bunch of arguments with specific names; Elegy uses Dependency Injection, so you can just request what you need.
- initializing tells us whether we should initialize our parameters or not. Here we are creating them directly ourselves, but if you use a Module system you can conditionally call its init method here.
- Our model is defined by a simple linear function.
- We defined a simple crossentropy loss and an accuracy metric, and added both to the logs.
- We set the updated States.net_params with the w and b parameters so we get them as an input on the next run after they are initialized. States.update offers a clean way to immutably update the states without having to copy all fields to a new States structure.
Remember that test_step only defines what happens during evaluate. However, Model's default implementation has a structure where one method is defined in terms of another:
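Roughly, the chain looks like this, with each method's default implementation calling the one to its right:

```text
train_step  →  grad_step  →  test_step  →  pred_step
   (fit)                     (evaluate)    (predict)
```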
Because of this, we get train_step / fit for free if we just pass an optimizer to the constructor, as we are going to do next:
Training
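A sketch of the training call (the constructor and fit arguments mirror the high-level examples above):

```python
import optax

model = LinearClassifier(optimizer=optax.adam(1e-3))

model.fit(
    inputs=x_train,
    labels=y_train,
    epochs=10,
    steps_per_epoch=200,
    batch_size=64,
)
```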
Notice that the logs are very noisy. This is because for this example we didn't use cumulative metrics, so the reported value is just the value for the last batch of that epoch, not the value for the entire epoch. To fix this we could use some of the modules in elegy.metrics.