NumPy+Jax with named axes and an uncompromising attitude
Does this resonate with you?

- In NumPy (and PyTorch and Jax et al.), broadcasting and batching and indexing are confusing and tedious.
- Einstein summation, meanwhile, is good.
- But why only Einstein summation? Why not Einstein everything?
- And why not have the arrays remember which axis goes where, so you don't have to keep repeating that?

If so, you might like this package.
- Python 3.10+
- Numpy
- Jax
- varname (Optional: For magical axis naming.)
- Pandas (Optional: If you want to use `dataframe`.)
- It's a single file: `numbat.py`
- Download it and put it in your directory.
- Done.
First of all, you don't have to use Numbat *instead* of Jax; you can use them together. Numbat is a different interface, but all the real work is still done by Jax. You can start by using Numbat inside your existing Jax code, in whichever spots it makes things easier. All the standard Jax features still work (GPUs, JIT compilation, gradients, etc.) and interoperate smoothly.
OK, but when would Numbat make things easier? Well, in NumPy (and Jax and PyTorch), easy things are already easy, and Numbat will not help. But hard things are often really hard, because:
- Indexing gets insanely complicated and tedious.
- Broadcasting gets insanely complicated and tedious.
- Writing "batched" code gets insanely complicated and tedious.
Ultimately, these all stem from the same issue: Numpy indexes different axes by position. This leads to constant, endless fiddling to get the axes of different arrays to align with each other. It also means that different library functions all have their own (varying, and often poorly documented) conventions on where the different axes are supposed to go and what happens when arrays of different numbers of dimensions are provided.
Numbat is an experiment. What if axes didn't have positions, but only names? Sure, the bits have to be laid out in some order, but why make the user think about that? Following many previous projects, let's define the shape to be a dictionary that maps names to ints. But what if we're totally uncompromising and only allow indexing using names? And what if we redesign indexing and broadcasting and batching around that representation? Does something nice happen?
This is still just a prototype. But I think it's enough to validate that the answer is yes: Something very nice happens.
Say you've got some array `X` containing data from different users, at different times and with different features. And you've got a few different subsets of users stored in `my_users`. And for each user, there is some subset of times you care about, stored in `my_times`. And for each user/time/subset combination, there is one feature you care about, stored in `my_feats`.
(To be clear: `X[u,t,f]` is the measurement of feature `f` at time `t` for user `u`, `my_users[i,k]` is user number `i` in subset number `k`, `my_times[j,i]` is the time for time number `j` and user number `i`, and `my_feats[j,k,i]` is the feature you care about at time number `j`, in subset number `k`, for user number `i`.)
So this is your situation:
```python
X.shape        == (n_user, n_time, n_feat)
my_users.shape == (100, 5)
my_times.shape == (20, 100)
my_feats.shape == (20, 5, 100)
```
You want to produce an array `Z` such that for all combinations of `i`, `j`, and `k`, the following is true:

```python
Z[i,j,k] == X[my_users[i,k], my_times[j,i], my_feats[j,k,i]]
```
What's the easiest way to do that in NumPy? Obviously `X[my_users, my_times, my_feats]` won't work. (Ha! Wouldn't that be nice!) In fact, the easiest answer turns out to be:

```python
Z = X[my_users[:,None], my_times.T[:,:,None], my_feats.transpose(2,0,1)]
```
Urf.
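Here's a quick sanity check of that one-liner, with random data and made-up sizes (the sizes and seed are arbitrary; they just have to match the shapes above):

```python
import numpy as np

# Made-up sizes, chosen only to match the shapes above.
n_user, n_time, n_feat = 500, 50, 30
rng = np.random.default_rng(0)

X = rng.normal(size=(n_user, n_time, n_feat))
my_users = rng.integers(0, n_user, size=(100, 5))
my_times = rng.integers(0, n_time, size=(20, 100))
my_feats = rng.integers(0, n_feat, size=(20, 5, 100))

Z = X[my_users[:,None], my_times.T[:,:,None], my_feats.transpose(2,0,1)]

# Spot-check a few entries against the spec.
for i in range(3):
    for j in range(3):
        for k in range(3):
            assert Z[i,j,k] == X[my_users[i,k], my_times[j,i], my_feats[j,k,i]]
```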
Here's how to do this in Numbat. First, you cast all the arrays to be named tensors, by labeling the axes.
```python
import numbat as nb

u, t, f = nb.axes()

x = nb.ntensor(X, u, t, f)
ny_users = nb.ntensor(my_users, u, f)
ny_times = nb.ntensor(my_times, t, u)
ny_feats = nb.ntensor(my_feats, t, f, u)
```

Then you index in the obvious way:

```python
z = x(u=ny_users, t=ny_times, f=ny_feats)
```
That's it. That does what you want. Instead of (maddening, slow, tedious, error-prone) manual twiddling to line up the axes, you label them and then have the computer line them up for you. Computers are good at that.
Say that along with `X`, we have some outputs `Y`. For each user and each time, there is some vector of outputs we want to predict. We want to use dead-simple ridge regression, with one regression fit for each user, for each output, and for each of several different regularization constants `R`.

To do this for a single user with a single output and a single regularization constant, remember the standard formula that

$$w = (X^\top X + r I)^{-1} X^\top y.$$
In this simple case, the code is a straightforward translation:
```python
import numpy as np

def simple_ridge(x, y, r):
    n_time, n_feat = x.shape
    n_time2, = y.shape
    assert n_time == n_time2
    # closed-form ridge solution: solve (x.T @ x + r*I) w = x.T @ y
    w = np.linalg.solve(x.T @ x + r * np.eye(n_feat), x.T @ y)
    return w
```
So here's the problem. You've got these three arrays:
```python
X.shape == (n_user, n_time, n_feat)
Y.shape == (n_user, n_time, n_pred)
R.shape == (n_reg,)
```
And you'd like to compute some array `W` that contains the results of `simple_ridge(X[u,:,:], Y[u,:,p], R[i])` for all `u`, `p`, and `i`. How to do that in NumPy?
Well, do you know what `numpy.linalg.solve(a, b)` does when `a` and `b` are high dimensional? The documentation is rather hard to parse. The simplest solution turns out to be:
```python
def triple_batched_ridge(X, Y, R):
    n_user, n_time, n_feat = X.shape
    n_user2, n_time2, n_pred = Y.shape
    assert n_user == n_user2
    assert n_time == n_time2
    XtX = np.sum(X.transpose(0,2,1)[:,:,:,None] * X[:,None,:,:], axis=2)
    XtY = X.transpose(0,2,1) @ Y
    W = np.linalg.solve(XtX[:,None,:,:] + R[None,:,None,None]*np.eye(n_feat), XtY[:,None,:,:])
    return W
```
Urrrrf.
Even seeing this function, can you tell how the output is laid out? Where in `W` does one find `simple_ridge(X[u,:,:], Y[u,:,p], R[i])`? Would that be in `W[u,p,i]` or `W[i,:,p,u]` or what? The answer turns out to be `W[u,i,:,p]`. Not because you want it there, but because the vagaries of `np.linalg.solve` mean that's where it goes.
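If you want to double-check that claim, here's a quick comparison against a plain loop over `simple_ridge`, with made-up sizes (the numbers are arbitrary, just small enough to loop over):

```python
import numpy as np

# Arbitrary small sizes for the sanity check.
n_user, n_time, n_feat, n_pred, n_reg = 4, 30, 3, 2, 5
rng = np.random.default_rng(0)

X = rng.normal(size=(n_user, n_time, n_feat))
Y = rng.normal(size=(n_user, n_time, n_pred))
R = rng.uniform(0.1, 1.0, size=n_reg)

W = triple_batched_ridge(X, Y, R)   # shape (n_user, n_reg, n_feat, n_pred)

for u in range(n_user):
    for p in range(n_pred):
        for i in range(n_reg):
            assert np.allclose(W[u,i,:,p], simple_ridge(X[u], Y[u,:,p], R[i]))
```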
But say you don't want to manually batch things. An alternative is to ask `jax.vmap` to do the batching for you. (Here `simple_ridge_jax` is `simple_ridge` rewritten with `jax.numpy`; there's a sketch of it after the corrected version below.) This is how you'd do that:
```python
triple_batched_ridge_jax = \
    jax.vmap(
        jax.vmap(
            jax.vmap(
                simple_ridge_jax,
                [None, 2, None]),   # vmap Y over p
            [0, 0, None]),          # vmap X and Y over u
        [None, None, 0])            # vmap R over r

W = triple_batched_ridge_jax(X, Y, R)
```
Simple enough, right? 🫡
Maybe. It's also completely wrong. The middle `vmap` absorbs the first dimension of `Y`, so in the innermost `vmap`, `p` is found in dimension 1, not dimension 2. (It's almost like referring to axes by position is confusing!) You also need to mess around with `out_axes` if you want to reproduce the layout of the manually batched function.
So what you actually want is this:
```python
triple_batched_ridge_jax = \
    jax.vmap(
        jax.vmap(
            jax.vmap(
                simple_ridge_jax,
                [None, 1, None],    # vmap Y over p
                out_axes=1),        # yeehaw
            [0, 0, None]),          # vmap X and Y over u
        [None, None, 0])            # vmap R over r

W = triple_batched_ridge_jax(X, Y, R)
```
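(In the above, `simple_ridge_jax` is assumed to be `simple_ridge` written in terms of `jax.numpy`, so that `jax.vmap` can trace it. A minimal sketch:)

```python
import jax.numpy as jnp

def simple_ridge_jax(x, y, r):
    # Same formula as simple_ridge, but using jax.numpy so jax.vmap can trace it.
    n_time, n_feat = x.shape
    return jnp.linalg.solve(x.T @ x + r * jnp.eye(n_feat), x.T @ y)
```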
Personally, I think this is much better than manual batching. But it still requires a lot of tedious manual tracking of axes as they flow through different operations.
So how would you do this in Numbat? Here's how:
```python
u, t, f, p, i = nb.axes()

x = nb.ntensor(X, u, t, f)
y = nb.ntensor(Y, u, t, p)
r = nb.ntensor(R, i)

fun = nb.lift(simple_ridge, in_axes=[[t,f],[t],[]], out_axes=[f])
w = fun(x, y, r)
```
Yup, that's it. That works. The `in_axes` argument tells `lift` that `simple_ridge` should operate on:

- A 2D array with axes `t` and `f`.
- A 1D array with axis `t`.
- A scalar.

And the `out_axes` says that it should return:

- A 1D array with axis `f`.
When `fun` is finally called, the inputs `x`, `y` and `r` all have named dimensions, so it knows exactly what it needs to do: It should operate on the `t` and `f` axes of `x` and the `t` axis of `y` and place the output along the `f` axis. Then it should broadcast over all other input dimensions.
And where does `simple_ridge(X[u,:,:], Y[u,:,p], R[i])` end up? Well, it's in the only place it could be: `w(u=u, p=p, i=i)`.
The above `lift` syntax is a bit clunky. If you prefer, you could write `fun = nb.lift(simple_ridge, 't f, t, -> f')` instead. This is completely equivalent.
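If you want to convince yourself of where things land, here's a hypothetical spot-check. (It assumes you can index an `ntensor` with plain integers and convert back with `.numpy`, as described in the reference list below; the particular indices are arbitrary.)

```python
import numpy as np

# Compare one slice of the Numbat result against the plain NumPy function.
w_ref = simple_ridge(X[2], Y[2,:,1], R[0])   # user 2, output 1, regularizer 0
w_nb = w(u=2, p=1, i=0)                      # an ntensor over the f axis
assert np.allclose(w_nb.numpy(f), w_ref)
```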
If you don't want to learn a lot of features, you can (in principle) do everything with Numbat just using a few functions.
- Use `ntensor` to create named tensors.
  - Use `A = ntensor([[1,2,3],[4,5,6]], 'i', 'j')` to create one.
  - Use `A+B` for (batched/broadcast) addition, `A*B` for multiplication, etc.
  - Use `A.shape` to get the shape (a dict).
  - Use `A.axes` to get the axes (a set).
  - Use `A.ndim` to get the number of dimensions (an int).
  - Use `A(i=i_ind, j=j_ind)` to index. (Don't use `A[i_ind, j_ind]`.)
  - Use `A.numpy('j', 'i')` to convert back to a regular Jax array.
- Use `dot` to do inner/outer/matrix/tensor products or Einstein summation.
  - Use `dot(A,B,C,D)` to sum along all shared axes. The order of the arguments does not matter!
  - Use `dot(A,B,C,D,keep={'i','j'})` to preserve some shared axes.
  - `A @ B` is equivalent to `dot(A,B)`.
- Use `batch` to create a batched function.
  - Use `batch(fun, {'i', 'j'})(A, B)` to apply `fun` to the axes `i` and `j` of `A` and `B`, broadcasting/batching over all other axes.
- Use `vmap` to create a vmapped function.
  - `vmap(fun, {'i', 'j'})(A, B)` applies `fun` to all axes that exist in either `A` or `B` except `i` and `j`, broadcasting/batching over `i` and `j`.
- Use `lift` to wrap Jax functions to operate on `ntensor`s instead of Jax/NumPy arrays.
  - `fun = lift(jnp.matmul, 'i j, j k -> i k')` creates a function that uses the `i` and `j` axes of the first argument and the `j` and `k` axes of the second argument.
  - Then, `fun(A,B)` is like `ntensor(jnp.matmul(A.numpy(i,j), B.numpy(j,k)), i, k)`, except it automatically broadcasts/vmaps over all input dimensions other than `i`, `j`, and `k`.
- Use `grad` and `value_and_grad` to compute gradients.
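To make the `dot` entry concrete, here's a small sketch (assuming `ntensor`, `dot`, and `.numpy` behave as described above):

```python
import numbat as nb

A = nb.ntensor([[1., 2., 3.], [4., 5., 6.]], 'i', 'j')   # shape {'i': 2, 'j': 3}
v = nb.ntensor([1., 0., 2.], 'j')                        # shape {'j': 3}

b = nb.dot(A, v)      # sums over the shared axis 'j'; shape {'i': 2}
b2 = nb.dot(v, A)     # the argument order doesn't matter
b3 = A @ v            # equivalent to dot(A, v)

print(b.numpy('i'))   # back to a plain Jax array
```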
API docs are at https://justindomke.github.io/numbat/
`ntensor` is registered with Jax as a Pytree node, so things like `jax.jit` and `jax.tree_flatten` work with `ntensor`s out of the box. For example, this is fine:
```python
import jax
import numbat as nb

x = nb.ntensor([1.,2.,3.],'i')

def fun(x):
    return nb.sum(x)

jax.jit(fun)(x) # works :)
```
Gradient functions like `jax.grad` and `jax.value_and_grad` also work out of the box, with one caveat: The output of the function must be a Jax scalar, not an `ntensor` scalar. For example, this does not work:
```python
import jax
import numbat as nb

x = nb.ntensor([1.,2.,3.],'i')

def fun(x):
    return nb.sum(x)

jax.grad(fun)(x) # doesn't work :(
```
The problem is that the return value is an `ntensor` with shape `{}`, which `jax.grad` doesn't know what to do with. You can fix this in two ways. First, you can convert a scalar `ntensor` to a Jax scalar using the special `.numpy()` syntax:
```python
import jax
import numbat as nb

x = nb.ntensor([1.,2.,3.],'i')

def fun(x):
    out = nb.sum(x)
    return out.numpy() # converts to a Jax scalar

jax.grad(fun)(x) # works!
```
Alternatively, you can use the `numbat.grad` wrapper, which does the conversion for you:
```python
import numbat as nb

x = nb.ntensor([1.,2.,3.],'i')

def fun(x):
    return nb.sum(x)

nb.grad(fun)(x) # works!
```
`jax.vmap` does not work. There's no way around this, since `jax.vmap` is based entirely on the positions of axes. Use `numbat.vmap` or `numbat.batch` instead.
- If you use the syntax `i,j,k = axes()` to create `Axis` objects, this uses evil magic from the `varname` package to try to figure out what the names of `i`, `j`, and `k` are. That package is kinda screwy and might give you errors like `VarnameRetrievingError` or `Couldn't retrieve the call node`. If that happens, try reinstalling `varname`. Or just give the names explicitly, like `i = Axis('i')`, etc.
- If you're using `jax.tree.*` utilities like `jax.tree.map`, these will by default descend into the numpy arrays stored inside of `ntensor` objects. You can use `jax.tree.map(..., ..., is_leaf=nb.is_ntensor)` to make sure `ntensor` objects are considered leaves.
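For instance, here's a sketch of that last gotcha and its fix (assuming `nb.is_ntensor` and scalar multiplication work as described above):

```python
import jax
import numbat as nb

params = {'w': nb.ntensor([1., 2., 3.], 'i')}

# With is_leaf=nb.is_ntensor, each ntensor is handed to the mapped function
# as a single leaf, instead of jax.tree.map descending into its internals.
doubled = jax.tree.map(lambda t: t * 2.0, params, is_leaf=nb.is_ntensor)
```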
You can do broadcasting in three ways:

- You can use `vmap`: `vmap(f, in_axes)(*args)` maps all arguments in `args` over all axes not in `in_axes`.
- You can use `batch`: `batch(f, axes)(*args)` will apply `f` to the axes in `axes`, broadcasting and vmapping over everything else.
- You can use `wrap`:
  - `wrap(f)(*args, axes=axes)` is equivalent to `batch(f, axes)(*args)`.
  - `wrap(f)(*args, vmap=in_axes)` is equivalent to `vmap(f, in_axes)(*args)`.
  - If you provide both `axes` and `vmap`, then the function checks that all axes are included in one or the other.
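As a rough illustration of how these relate (a sketch only; it assumes `batch`/`vmap`/`wrap` hand the wrapped function `ntensor`s restricted to the named axes, and that `nb.sum` reduces over all axes of its input):

```python
import numbat as nb

def total(v):
    return nb.sum(v)    # reduce over whatever axes v has

A = nb.ntensor([[1., 2.], [3., 4.], [5., 6.]], 'j', 'i')

s1 = nb.batch(total, {'i'})(A)       # apply total to the 'i' axis, batch over 'j'
s2 = nb.vmap(total, {'j'})(A)        # batch over 'j', apply total to everything else
s3 = nb.wrap(total)(A, axes={'i'})   # equivalent to s1, per the bullets above
# s1, s2, and s3 should all be ntensors with shape {'j': 3}.
```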
- xarray, and the many efforts towards integrating it with Jax
- named tensors (in PyTorch)
- Named Tensor Notation (for math)
(Please let me know about any other related packages.)