Einsums in the wild
By Gerardo Durán-Martín (@grrddm)
Adapted from grrdm.notion.so
Do you know what
inm,kij,jnm->knm
is all about?
Introduction
Linear combinations are ubiquitous in machine learning and statistics. Many algorithms and models in the statistics and machine learning literature can be written (or approximated) as a matrix-vector multiplication. Einsums are a way of representing the linear interaction among vectors, matrices, and higher-dimensional arrays.
In this post, I lay out examples that make use of einsums. I assume that the reader is familiar with the basics of einsums; however, I provide a quick introduction in the next section. For further references, see also [1] and [2]. Throughout this post, we borrow from the numpy literature and refer to an element $\mathbf{A} \in \mathbb{R}^{d_1 \times \cdots \times d_n}$ as an $n$-dimensional array.
In the next section, we present a brief summary of einsum expressions and their usage in numpy / jax.numpy.
A quick introduction to einsums: from sums to indices
Let $a \in \mathbb{R}^N$ be a 1-dimensional array. We denote by $a_i$ the $i$-th element of $a$. Suppose we want to express the sum over all elements in $a$. This can be written as

$$\sum_{i=1}^{N} a_i.$$
To introduce the einsum notation, we notice that the sum symbol ($\Sigma$) in this equation simply states that we should consider all elements of $a$ and sum them. If we assume that 1) there is no ambiguity on the number of dimensions in $a$ and 2) we sum over all of its elements, we define the einsum notation for the sum over all elements in the 1-dimensional array $a$ simply as $a_i$: an index written without an explicit summation symbol is summed over.
To keep our notation consistent, we denote indices written with parentheses as static dimensions. Static dimensions allow us to expand the expressive power of einsums. That is, we denote all of the elements of $a$ under the einsum notation as $a_{(i)}$.
Since the names of the arrays are not necessarily meaningful to define these expressions, we define einsum expressions in numpy by focusing only on the indices. To represent which dimensions are static and which should be summed over, we introduce the `->` notation: indices to the left of `->` define the set of indices of an array, and indices to the right of `->` represent the indices that we do not sum over. For example, the sum over all elements in $a$ is written as

$$a_i \longleftrightarrow \texttt{np.einsum("i->", a)},$$

and the selection of all elements of $a$ is written as

$$a_{(i)} \longleftrightarrow \texttt{np.einsum("i->i", a)}.$$
In the following snippet, we show this notation in action.
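A minimal sketch of both expressions, using `jax.numpy` with a small hypothetical array (the plain `numpy` call is identical):

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])

# "i->": the index i is summed over, i.e. we add up all elements of a
total = jnp.einsum("i->", a)    # 6.0, same as a.sum()

# "i->i": the index i is static, so we simply recover the elements of a
same = jnp.einsum("i->i", a)    # [1., 2., 3.]
```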
Higher-dimensional arrays
Let $a$ and $b$ be two 1-dimensional arrays of size $N$. The dot product between $a$ and $b$ can be written as

$$a^\top b = \sum_{i=1}^{N} a_i\, b_i.$$

Following our previous notation, we see that the representation of this einsum expression in mathematical and numpy form is

$$a_i\, b_i \longleftrightarrow \texttt{np.einsum("i,i->", a, b)}.$$

Furthermore, the einsum representation of the element-wise product between $a$ and $b$ is given by

$$a_{(i)}\, b_{(i)} \longleftrightarrow \texttt{np.einsum("i,i->i", a, b)}.$$
As an example, consider the following 1-dimensional arrays `a` and `b`.
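A short sketch with two hypothetical arrays, checking both expressions against their `@` and `*` counterparts:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# "i,i->": dot product, the repeated index i is summed over
dot = jnp.einsum("i,i->", a, b)        # 32.0, same as a @ b

# "i,i->i": element-wise (Hadamard) product, the index i stays static
hadamard = jnp.einsum("i,i->i", a, b)  # [4., 10., 18.], same as a * b
```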
We can generalise the ideas previously presented. Consider the matrix-vector multiplication between $\mathbf{A} \in \mathbb{R}^{M \times N}$ and $x \in \mathbb{R}^N$. We can write this in linear algebra form as

$$\mathbf{A} x = \begin{bmatrix} \mathbf{a}_1^\top x \\ \vdots \\ \mathbf{a}_M^\top x \end{bmatrix},$$

where we have denoted $\mathbf{a}_i^\top$ as the $i$-th row of $\mathbf{A}$. From the equation above, we notice that the $i$-th entry of $\mathbf{A} x$ can be expressed as

$$(\mathbf{A} x)_i = \sum_j A_{i,j}\, x_j.$$

We also observe that the first dimension of $\mathbf{A}$ for this expression is static. The einsum representation in mathematical/numpy form becomes

$$A_{(i),j}\, x_j \longleftrightarrow \texttt{np.einsum("ij,j->i", A, x)}.$$
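A quick sketch with hypothetical values:

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])   # shape (3, 2)
x = jnp.array([1.0, -1.0])    # shape (2,)

# "ij,j->i": sum over the shared index j, keep the row index i static
Ax = jnp.einsum("ij,j->i", A, x)   # [-1., -1., -1.], same as A @ x
```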
Considering the result of the last example, we can easily express the $(i,k)$-th entry of the multiplication between two matrices. Let $\mathbf{A} \in \mathbb{R}^{M \times N}$ and $\mathbf{B} \in \mathbb{R}^{N \times K}$; the product between $\mathbf{A}$ and $\mathbf{B}$ becomes $\mathbf{C} = \mathbf{A}\mathbf{B} \in \mathbb{R}^{M \times K}$. Then, the $(i,k)$-th entry of the matrix-matrix multiplication can be expressed as

$$C_{i,k} = \sum_j A_{i,j}\, B_{j,k}.$$

From the equation above, we see that the first dimension of $\mathbf{A}$ and the second dimension of $\mathbf{B}$ are static. We represent its einsum form as

$$A_{(i),j}\, B_{j,(k)} \longleftrightarrow \texttt{np.einsum("ij,jk->ik", A, B)}.$$
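And the corresponding sketch:

```python
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
B = jnp.arange(12.0).reshape(3, 4)  # shape (3, 4)

# "ij,jk->ik": sum over the shared index j, keep i and k static
C = jnp.einsum("ij,jk->ik", A, B)   # shape (2, 4), same as A @ B
```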
Even-higher-dimensional arrays
The advantage of using einsums in machine learning is their expressive power when working with higher-dimensional arrays. As we will see, knowing the einsum representation of a matrix-vector multiplication operation easily allows us to generalise it to multiple dimensions. This is because einsums can be thought of as expressions of linear transformations when static dimensions are present in the output.
To motivate the use of einsum expressions for linear combinations in machine learning, we consider the following example.
Einsums in machine learning
Let $x \in \mathbb{R}^M$ and $\mathbf{\Lambda} \in \mathbb{R}^{M \times M}$ be one-dimensional and two-dimensional arrays respectively. The squared Mahalanobis distance centred at zero with precision matrix $\mathbf{\Lambda}$ is defined as

$$\Delta^2(x) = x^\top \mathbf{\Lambda}\, x.$$

Using the typical rules for matrix-vector multiplication, we can evaluate $\Delta^2(x)$ for any given $x$ and a valid precision matrix $\mathbf{\Lambda}$. We can readily evaluate $\Delta^2(x)$ as the einsum expression `i,ij,j->`. This is because

$$x^\top \mathbf{\Lambda}\, x = \sum_i \sum_j x_i\, \Lambda_{i,j}\, x_j.$$
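A one-line check with hypothetical values of $x$ and $\mathbf{\Lambda}$:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
Lambda = jnp.array([[2.0, 0.5],
                    [0.5, 1.0]])   # an (assumed) valid precision matrix

# "i,ij,j->": both indices are summed over, giving a scalar
maha = jnp.einsum("i,ij,j->", x, Lambda, x)   # same as x @ Lambda @ x
```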
A more interesting scenario is to consider the case where we have $N$ observations stored in a 2-dimensional array $\mathbf{X} \in \mathbb{R}^{N \times M}$. If we denote by $x_n$ the $n$-th observation (row) in $\mathbf{X}$, computing the squared Mahalanobis distance for each observation means obtaining

$$\left(x_1^\top \mathbf{\Lambda}\, x_1,\; \ldots,\; x_N^\top \mathbf{\Lambda}\, x_N\right).$$

One way to obtain this collection of squared Mahalanobis distances evaluated at each of the elements in $\mathbf{X}$ is to compute

$$\operatorname{diag}\left(\mathbf{X}\, \mathbf{\Lambda}\, \mathbf{X}^\top\right),$$

where $\mathbf{X}\, \mathbf{\Lambda}\, \mathbf{X}^\top \in \mathbb{R}^{N \times N}$. To see why, note that

$$\left(\mathbf{X}\, \mathbf{\Lambda}\, \mathbf{X}^\top\right)_{n,m} = x_n^\top \mathbf{\Lambda}\, x_m,$$

so that

$$\operatorname{diag}\left(\mathbf{X}\, \mathbf{\Lambda}\, \mathbf{X}^\top\right) = \left(x_1^\top \mathbf{\Lambda}\, x_1,\; \ldots,\; x_N^\top \mathbf{\Lambda}\, x_N\right).$$
The computation of the above expression is inefficient, since we need to compute $N^2$ terms to obtain our desired 1-dimensional array of size $N$. A much more efficient way to compute and express the set of squared Mahalanobis distances is to make use of einsums. As we've seen, $x^\top \mathbf{\Lambda}\, x$ is written in einsum form as `i,ij,j->`. The extension of the latter expression to a set of $N$ elements is straightforward by noting that we only need to specify that the first dimension of $\mathbf{X}$ is static. We obtain `ni,ij,nj->n`. In this particular example, using einsums helps us avoid computing the terms that are not on the diagonal, which increases the speed at which we can compute this expression compared to the traditional method.
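A small sketch comparing both approaches on hypothetical data; the results agree, but the einsum version never forms the $N \times N$ matrix:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 2))   # N = 1000 observations in R^2
Lambda = jnp.array([[2.0, 0.5],
                    [0.5, 1.0]])

# Traditional method: computes the full (N, N) matrix, then discards the off-diagonal
maha_slow = jnp.diag(X @ Lambda @ X.T)

# Einsum method: "ni,ij,nj->n" only computes the N diagonal terms
maha_fast = jnp.einsum("ni,ij,nj->n", X, Lambda, X)

assert jnp.allclose(maha_slow, maha_fast, atol=1e-4)
```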
To generalise this result, consider a 3-dimensional array $\mathbf{X}$ whose last dimension has size $M$ (matching $\mathbf{\Lambda} \in \mathbb{R}^{M \times M}$). The algebraic representation of the squared Mahalanobis distance evaluated over the last dimension of $\mathbf{X}$ is not possible using the basic tools of linear algebra (we will see a use case of this when we show how to plot the log-density of a 2-dimensional Gaussian distribution). From our previous result, we see that to expand this computation for $\mathbf{X}$, we only need to introduce an additional static index to obtain `nmi,ij,nmj->nm`.
What these last expressions show is that einsums can be of great help in scenarios where we have to compute a known linear transformation across additional static (batch) indices.
If we continue with the process of increasing the dimension of $\mathbf{X}$, we obtain the following results:
- `i,ij,j->` for scalar output,
- `ni,ij,nj->n` for a 1-dimensional array output,
- `nmi,ij,nmj->nm` for a 2-dimensional array (grid) output,
- `nmli,ij,nmlj->nml` for a 3-dimensional array output, and
- `...i,ij,...j->...` for a $d$-dimensional array output.
Furthermore, einsum expressions are commutative over blocks of indices. This means that the result of the einsum expression is independent of the order in which the arrays are positioned, as long as each block of indices moves with its array. For our previous example, the following three expressions are equivalent: `np.einsum("ni,ij,nj->n", X, Lambda, X)`, `np.einsum("ij,ni,nj->n", Lambda, X, X)`, and `np.einsum("ni,nj,ij->n", X, X, Lambda)`.
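A quick numerical check of this equivalence on hypothetical arrays:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (5, 2))
Lambda = jnp.array([[2.0, 0.5],
                    [0.5, 1.0]])

# Reordering the arrays (and their index blocks) leaves the result unchanged
r1 = jnp.einsum("ni,ij,nj->n", X, Lambda, X)
r2 = jnp.einsum("ij,ni,nj->n", Lambda, X, X)
r3 = jnp.einsum("ni,nj,ij->n", X, X, Lambda)

assert jnp.allclose(r1, r2) and jnp.allclose(r1, r3)
```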
Log-density of a Gaussian
Let $x \sim \mathcal{N}(\mu, \mathbf{\Sigma})$; the log-density of $x$ is given by

$$\log p(x) = -\frac{1}{2}\,(x - \mu)^\top \mathbf{\Sigma}^{-1} (x - \mu) - \frac{1}{2}\log\det\left(2\pi\mathbf{\Sigma}\right).$$
Suppose we want to plot the log-density of a bivariate Gaussian distribution, up to a normalisation constant, over a region $\mathcal{X} \subseteq \mathbb{R}^2$. As we have previously seen, the expression $(x - \mu)^\top \mathbf{\Lambda}\, (x - \mu)$ can be represented in einsum form as `i,ij,j->`. By introducing static dimensions `n` and `m`, we compute the log-density over the region by adding the `n` and `m` indices to the einsum expression and specifying them as the final result; we obtain `nmi,ij,nmj->nm`. A common way to obtain the grid in python is through `jnp.mgrid`, which places the coordinate axis first, so the grid has shape $(2, N, M)$ and the equivalent expression becomes `inm,ij,jnm->nm`. We present an example of this below.
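A minimal sketch of this computation, assuming an example mean $\mu$ and covariance $\mathbf{\Sigma}$ over the region $[-3, 3] \times [-3, 3]$:

```python
import jax.numpy as jnp

mu = jnp.array([0.0, 0.0])
Sigma = jnp.array([[1.0, 0.5],
                   [0.5, 1.0]])
Lambda = jnp.linalg.inv(Sigma)             # precision matrix

# jnp.mgrid puts the coordinate axis first, so X has shape (2, N, M)
X = jnp.mgrid[-3:3:100j, -3:3:100j]
diff = X - mu[:, None, None]               # centre the grid at mu

# Unnormalised log-density over the grid: "inm,ij,jnm->nm"
log_density = -0.5 * jnp.einsum("inm,ij,jnm->nm", diff, Lambda, diff)
print(log_density.shape)                   # (100, 100)
```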
We expand the previous idea to the case of a set of multivariate Gaussians with constant mean and multiple covariance matrices.
Recall that the einsum expression to compute the log-density of a bivariate normal over a region is given by `inm,ij,jnm->nm`. Assume now that we have a set of $K$ Gaussian distributions: for each index $k$, we have a precision matrix $\mathbf{\Lambda}_k$ and a constant mean $\mu$. To compute the density over the region for each of the $K$ distributions, we simply modify our previous expression to take into account a new static dimension `k`. We obtain `inm,kij,jnm->knm`.
As an example, consider a collection of four covariance matrices `C1`, `C2`, `C3`, `C4`. We show below that einsums can be used to compute the log-density over the multiple Gaussians.
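A sketch with four hypothetical covariance matrices (any set of valid covariances works):

```python
import jax.numpy as jnp

mu = jnp.array([0.0, 0.0])

# A hypothetical collection of four covariance matrices
C1 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
C2 = jnp.array([[1.0, 0.7], [0.7, 1.0]])
C3 = jnp.array([[1.0, -0.7], [-0.7, 1.0]])
C4 = jnp.array([[0.5, 0.0], [0.0, 2.0]])
covariances = jnp.stack([C1, C2, C3, C4])      # shape (K, 2, 2) with K = 4
precisions = jnp.linalg.inv(covariances)       # batched inverse, shape (K, 2, 2)

X = jnp.mgrid[-3:3:100j, -3:3:100j]            # shape (2, N, M)
diff = X - mu[:, None, None]

# One (N, M) grid of unnormalised log-densities per precision matrix
log_densities = -0.5 * jnp.einsum("inm,kij,jnm->knm", diff, precisions, diff)
print(log_densities.shape)                     # (4, 100, 100)
```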
To recap, the einsum expression for the Mahalanobis distance is given by:
- `i,ij,j->` for a single $M$-dimensional array (a vector),
- `ni,ij,nj->n` for a collection of $M$-dimensional arrays (a matrix of observations),
- `nmi,ij,nmj->nm` for a grid of $M$-dimensional arrays, and
- `nmi,kij,nmj->knm` for a grid of $M$-dimensional arrays evaluated over $K$ different precision matrices.
Predictive surface of a Bayesian logistic regression model
As long as the inner-most operation we want to compute consists of a linear combination of elements, we can make use of einsums. As a next example, consider the estimation of the predictive surface of a logistic regression model with a Gaussian prior. That is, we want to compute

$$p(y = 1 \mid x, \mathcal{D}) = \int \sigma\!\left(w^\top x\right) p(w \mid \mathcal{D})\, dw,$$

where $\sigma(\cdot)$ is the logistic (sigmoid) function.
Suppose we have estimated the posterior $p(w \mid \mathcal{D})$. Since the posterior predictive distribution is analytically intractable, we turn to a Monte Carlo approximation of the posterior predictive surface. As with the previous two examples, we want to compute this quantity over a grid $\mathcal{X} \subseteq \mathbb{R}^2$. In this scenario, we have $S$ weights $w_1, \ldots, w_S$ sampled from the posterior, which we wish to evaluate at every point in the grid. Recalling that the dot product between two vectors is written in einsum form as `m,m->`, to obtain a 3-dimensional array comprising the $S$ simulations evaluated at each point in the grid, we simply expand the dot product expression to contain the static index `s` for the simulation and `i,j` for the position in the grid. We obtain `sm,mij->sij`. After obtaining the resulting `sij` array of linear predictors, we compute the approximated predictive distribution by applying the logistic function over each element and averaging over `s`. The following code shows how to achieve this.
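The sketch below assumes a hypothetical Gaussian approximation to the posterior, $\mathcal{N}(\texttt{w\_map}, \texttt{Sigma\_post})$, from which we draw the `S` weight samples; `w_map` and `Sigma_post` are placeholders for whatever posterior estimate is available:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Hypothetical Gaussian posterior over 2 weights
w_map = jnp.array([1.0, -2.0])
Sigma_post = jnp.array([[0.5, 0.1],
                        [0.1, 0.3]])

# S samples of weights from the (assumed) Gaussian posterior, shape (S, 2)
S = 500
w_samples = jax.random.multivariate_normal(key, w_map, Sigma_post, (S,))

# Grid over the input space, shape (2, N, M)
X = jnp.mgrid[-3:3:100j, -3:3:100j]

# Linear predictor for every sample at every grid point: "sm,mij->sij"
logits = jnp.einsum("sm,mij->sij", w_samples, X)

# Monte Carlo estimate of the predictive surface: apply sigmoid, average over s
predictive_surface = jax.nn.sigmoid(logits).mean(axis=0)   # shape (100, 100)
```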
Image compression: Singular value decomposition
A typical example that one encounters when learning about the singular value decomposition (SVD) is the use of SVD to compress an image. As a heuristic example, suppose we want to compare the SVD approximation of an image, stored as a 2-dimensional array $\mathbf{A} \in \mathbb{R}^{m \times n}$, over multiple thresholds.
It’s a classical result of linear algebra that our matrix $\mathbf{A}$ can be factorised as

$$\mathbf{A} = \mathbf{U}\,\mathbf{\Sigma}\,\mathbf{V}^\top,$$

where $\mathbf{U}$ and $\mathbf{V}$ are orthogonal matrices and $\mathbf{\Sigma}$ is a matrix with diagonal terms $\sigma_1 \geq \sigma_2 \geq \cdots \geq 0$ and zero everywhere else. In scipy, the SVD decomposition of $\mathbf{A}$ is conveniently factorised (in einsum form) as

$$A_{(i),(j)} = U_{(i),k}\, s_k\, (V^\top)_{k,(j)} \longleftrightarrow \texttt{np.einsum("ik,k,kj->ij", U, s, Vh)},$$

where $\mathbf{U} \in \mathbb{R}^{m \times K}$, $\mathbf{s} \in \mathbb{R}^{K}$, $\mathbf{V}^\top \in \mathbb{R}^{K \times n}$, and $K = \min(m, n)$ (the reduced SVD returned by `scipy.linalg.svd` with `full_matrices=False`).
As a pedagogical example, suppose we wish to approximate the matrix $\mathbf{A}$ using the first $k_0$ singular components. First, we observe that the $(i,j)$-th entry of $\mathbf{A}$ is given by

$$A_{i,j} = \sum_{k=1}^{K} U_{i,k}\, s_k\, (V^\top)_{k,j}.$$

If we wish to consider only the first $k_0$ components of $\mathbf{A}$, we only need to modify the upper limit of the sum to obtain

$$\hat{A}^{(k_0)}_{i,j} = \sum_{k=1}^{k_0} U_{i,k}\, s_k\, (V^\top)_{k,j}.$$

However, this last expression cannot be represented in einsum notation. As we mentioned at the beginning, every einsum expression assumes that the sum is over all chosen indices. To get around this constraint, we simply introduce a 1-dimensional vector $q$ of size $K$ that has value $1$ for the first $k_0$ entries and $0$ for the remaining elements. Hence, we write the approximation of the matrix $\mathbf{A}$ using the first $k_0$ singular components as

$$\hat{A}^{(k_0)}_{i,j} = \sum_{k=1}^{K} U_{i,k}\, s_k\, q_k\, (V^\top)_{k,j}.$$

We observe that this is easily written in einsum form as `ik,k,k,kj->ij`.

We could also consider multiple values of $k_0$. To do this, we define a 2-dimensional array $\mathbf{Q}$ whose $c$-th row contains ones in its first $k_c$ entries and zeros elsewhere. Next, we simply modify our previous expression to take into account the additional static dimension of the matrix $\mathbf{Q}$. We obtain `ik,k,ck,kj->cij`.
We provide an example of this idea in the next code snippet: first, we load an image stored in a 3-dimensional array `img`. Next, we transform `img` to obtain a 2-dimensional array `img_gray`. We perform SVD over `img_gray` and define a matrix `indexer` containing our different thresholds. Finally, we make use of our previously-defined expression to compute the SVD approximation of the image at the different values defined in `indexer`.
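A sketch assuming SciPy's bundled sample image via `scipy.datasets.face`; any RGB image stored as a 3-dimensional array would do, and the thresholds below are arbitrary:

```python
import numpy as np
from scipy.linalg import svd
from scipy.datasets import face   # sample RGB image (requires pooch); any image works

# Load an RGB image (3-dimensional array) and convert it to grayscale
img = face()                                   # shape (768, 1024, 3)
img_gray = img.mean(axis=-1)                   # shape (768, 1024)

# Reduced SVD: U (m, K), s (K,), Vh (K, n) with K = min(m, n)
U, s, Vh = svd(img_gray, full_matrices=False)

# indexer[c, k] = 1 if k is below the c-th threshold, else 0
thresholds = np.array([5, 20, 50, 200])
K = s.shape[0]
indexer = (np.arange(K)[None, :] < thresholds[:, None]).astype(float)  # shape (4, K)

# One rank-k approximation per threshold: "ik,k,ck,kj->cij"
approximations = np.einsum("ik,k,ck,kj->cij", U, s, indexer, Vh)
print(approximations.shape)                    # (4, 768, 1024)
```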
Misc examples
- Computing the state-value and action-value functions in a tabular state space. For an example, see this notebook.
- Diagonal extended Kalman filter (dEKF). For an example, see this script.