
🚛 The old repository of this project (ndimsplinejax@nmoteki) has moved here because my old GitHub account, nmoteki, is no longer maintained.
An interpolant is an efficiently computable mathematical function that models a discrete dataset and reproduces the data exactly at the data points. Interpolants are indispensable for incorporating observational data into physical simulations or statistical inference without appreciable bias. This contrasts with regression models (e.g., multilayer perceptrons), which almost always suffer from underfitting or overfitting to some extent.
Many interpolation codes and software packages are available; however, when I started this project in mid-2022, I could not find any multidimensional interpolant compatible with both just-in-time (JIT) compilation and automatic differentiation (AD). My research needed such an interpolant to apply a recent Hamiltonian Monte Carlo (HMC) code, which requires gradients of the log-posterior, to a Bayesian inverse problem whose forward model is accessible only through a precomputed discrete look-up table. The forward model in that case, a light-scattering simulator for nonspherical particles, is too computationally expensive to run on the fly. So I developed NdimSpline_JAX, and I share the code in the hope that it is useful to scientists and engineers.
- The `SplineCoefs_from_GriddedData` module computes the natural-cubic spline coefficients of the interpolant from scalar y data given on an N-dimensional Cartesian x grid.
- The `SplineInterpolant` module generates a JIT- and autograd-compatible interpolant from the spline coefficients.

- On each axis, the x grid points must be equidistant. The grid spacing may differ between axes.
- The current version supports 1- to 5-dimensional x spaces (N <= 5).
- The author expects the equidistant-grid requirement not to be a serious limitation in practice: data on a non-equidistant grid can often be mapped (or approximated) onto an equidistant grid by a mathematical transformation of each variable, e.g., a logarithmically spaced axis becomes equidistant in log x.
- In higher dimensions or on finer grids, execution may be limited by available memory and computation time; the gridded y data alone hold (n_1+1)(n_2+1)...(n_N+1) values.
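To make "natural-cubic spline coefficients" and the role of the equidistant grid concrete, here is a minimal one-dimensional sketch in plain Python. It assumes the cubic B-spline formulation of Habermann and Kindermann (2007), cited in the references, in which n grid intervals give n + 3 coefficients; whether the repository's modules use exactly this indexing is an assumption, and the function names `cubic_kernel`, `natural_spline_coefs_1d`, and `spline_eval_1d` are hypothetical, not the API of `SplineCoefs_from_GriddedData` or `SplineInterpolant`. The usage part applies the variable transformation mentioned above: data on a logarithmically spaced x grid are interpolated on the equidistant grid of v = log10(x).

```python
import math


def cubic_kernel(t):
    """Cubic B-spline kernel u(t) of Habermann and Kindermann (2007); zero for |t| > 2."""
    t = abs(t)
    if t <= 1.0:
        return 4.0 - 6.0 * t**2 + 3.0 * t**3
    if t <= 2.0:
        return (2.0 - t) ** 3
    return 0.0


def natural_spline_coefs_1d(y):
    """Hypothetical helper (not the repository API): return the n+3
    coefficients of the natural cubic spline through n+1 equidistant samples y."""
    n = len(y) - 1
    c = [0.0] * (n + 3)
    # With s''(a) = s''(b) = 0, the first and last interpolation
    # equations reduce to 6*c[1] = y[0] and 6*c[n+1] = y[n].
    c[1] = y[0] / 6.0
    c[n + 1] = y[n] / 6.0
    # Interior equations c[j] + 4*c[j+1] + c[j+2] = y[j] (j = 1..n-1)
    # form a tridiagonal system for the unknowns c[2..n].
    d = [float(y[j]) for j in range(1, n)]
    m = len(d)
    if m > 0:
        d[0] -= c[1]
        d[-1] -= c[n + 1]
        # Thomas algorithm: diagonal 4, sub- and super-diagonal 1.
        cp, dp = [0.0] * m, [0.0] * m
        cp[0], dp[0] = 0.25, d[0] / 4.0
        for i in range(1, m):
            denom = 4.0 - cp[i - 1]
            cp[i] = 1.0 / denom
            dp[i] = (d[i] - dp[i - 1]) / denom
        c[m + 1] = dp[m - 1]
        for i in range(m - 2, -1, -1):
            c[i + 2] = dp[i] - cp[i] * c[i + 3]
    # The natural end conditions also fix the two outermost coefficients.
    c[0] = 2.0 * c[1] - c[2]
    c[n + 2] = 2.0 * c[n + 1] - c[n]
    return c


def spline_eval_1d(x, a, h, c):
    """Evaluate s(x) = sum_k c[k] * u((x - a)/h + 1 - k) with 0-based k."""
    t0 = (x - a) / h
    return sum(ck * cubic_kernel(t0 + 1 - k) for k, ck in enumerate(c))


# Usage: samples on a log-spaced grid x = 10**0, 10**0.25, ..., 10**2 are
# equidistant in the transformed variable v = log10(x), so the spline is built in v.
n = 8
a_v, b_v = 0.0, 2.0
h_v = (b_v - a_v) / n
x_nodes = [10 ** (a_v + j * h_v) for j in range(n + 1)]
y_nodes = [math.sqrt(x) for x in x_nodes]  # synthetic data, y = sqrt(x)
c = natural_spline_coefs_1d(y_nodes)


def s_of_x(x):
    """Interpolant in the original variable x (valid for 1 <= x <= 100)."""
    return spline_eval_1d(math.log10(x), a_v, h_v, c)


print(s_of_x(50.0), math.sqrt(50.0))
```

The equidistant spacing is what makes the kernel argument (x - a)/h cheap to compute and the linear system a constant-coefficient tridiagonal one.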
The following is an example of how to use the modules on your local computer. The author tested the code with Python 3.12.8 on Windows 11 and Python 3.12.9 on WSL (Ubuntu).
- A Python >= 3.12 environment on Linux, macOS, or WSL2 on Windows
- Installation of the `jax` package, and optionally the `ipykernel` package if you run the Jupyter Notebook files
```sh
git clone https://github.com/NobuhiroMoteki/NdimSpline_JAX.git
```
The workflow is illustrated below for a 5-dimensional x space (N=5):
- Define the grid information. For the x coordinates, define the N-list `a` of lower bounds, the N-list `b` of upper bounds, and the N-list `n` of numbers of grid intervals.

```python
a = [0, 0, 0, 0, 0]       # user-defined lower bound of each x coordinate [1st dim, ..., Nth dim]
b = [1, 2, 3, 4, 5]       # user-defined upper bound of each x coordinate [1st dim, ..., Nth dim]
n = [10, 10, 10, 10, 10]  # user-defined number of grid intervals along each x coordinate [1st dim, ..., Nth dim]
```
- Prepare the observation data `y_data` on the x grid points.

```python
import numpy as np

N = len(a)  # dimension N

# Make an N-tuple of numpy arrays of x-gridpoint values
x_grid = ()
for j in range(N):
    x_grid += (np.linspace(a[j], b[j], n[j] + 1),)

# Make an N-dimensional numpy array of y_data
grid_shape = ()
for j in range(N):
    grid_shape += (n[j] + 1,)
y_data = np.zeros(grid_shape)

# A synthetic y_data (should be replaced by user-defined data in actual use):
for q1 in range(n[0] + 1):
    for q2 in range(n[1] + 1):
        for q3 in range(n[2] + 1):
            for q4 in range(n[3] + 1):
                for q5 in range(n[4] + 1):
                    y_data[q1, q2, q3, q4, q5] = np.sin(x_grid[0][q1])*np.sin(x_grid[1][q2])*np.sin(x_grid[2][q3])*np.sin(x_grid[3][q4])*np.sin(x_grid[4][q5])
```
- Compute the spline coefficients from the data using the `SplineCoefs_from_GriddedData` module.

```python
# Import the module.
from SplineCoefs_from_GriddedData import SplineCoefs_from_GriddedData

# Make an instance of the class SplineCoefs_from_GriddedData
spline_coef = SplineCoefs_from_GriddedData(a, b, y_data)

# Compute the spline coefficients c_i1...iN
# (the author recommends an N-explicit name for the coefficient array, for readability)
c_i1i2i3i4i5 = spline_coef.Compute_Coefs()
```
- Generate the JIT- and AD-capable interpolant from the coefficients using the `SplineInterpolant` module.

```python
# Import the module.
from SplineInterpolant import SplineInterpolant

# Build the jittable and auto-differentiable interpolant from the spline coefficients c_i1i2i3i4i5.
spline = SplineInterpolant(a, b, n, c_i1i2i3i4i5)
```
- Use the generated interpolant with the JIT and autograd functionality of `jax`.

```python
import jax.numpy as jnp
from jax import jit, grad, value_and_grad

# Specify an x coordinate for function evaluation as a jnp array.
# By definition, x must satisfy the elementwise inequality a <= x <= b.
x = jnp.array([0.7, 1.0, 1.5, 2.0, 2.5])

# Call the method of the 5-dimensional interpolant s5D of the "spline" instance (without JIT).
# For N dimensions, call the sND method (N is one of 1, 2, 3, 4, 5).
print(spline.s5D(x))

# Compute the automatic gradient of spline.s5D(x) at the specified x coordinate
ds5D = grad(spline.s5D)
print(ds5D(x))

# Compute both the value and the gradient of spline.s5D(x) at the specified x coordinate
s5D_fun = value_and_grad(spline.s5D)
print(s5D_fun(x))

# Jitted version of spline.s5D(x) at the specified x coordinate
s5D_jitted = jit(spline.s5D)
print(s5D_jitted(x))

# Compute the jitted automatic gradient of spline.s5D(x) at the specified x coordinate
ds5D_jitted = jit(grad(spline.s5D))
print(ds5D_jitted(x))

# Jitted value and gradient
s5D_fun_jitted = jit(value_and_grad(spline.s5D))
print(s5D_fun_jitted(x))
```
- Compare the computation time of the spline interpolant between the non-jitted and jitted versions (IPython/Jupyter `%timeit` magic).

```python
%timeit spline.s5D(x)       # function evaluation
%timeit s5D_jitted(x)       # function evaluation (jitted)
%timeit ds5D(x)             # gradient evaluation
%timeit ds5D_jitted(x)      # gradient evaluation (jitted)
%timeit s5D_fun(x)          # function and its gradient evaluation
%timeit s5D_fun_jitted(x)   # function and its gradient evaluation (jitted)
```
The jitted version should be faster than the non-jitted version by 1 to 3 orders of magnitude. I observed a larger difference for smaller n (i.e., coarser grids with fewer grid points).
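JAX traces and compiles a jitted function on its first call and dispatches work asynchronously, so the JAX documentation recommends a warm-up call and `block_until_ready()` when benchmarking. A minimal sketch of that pattern follows; the function `f` is a hypothetical stand-in for `spline.s5D` so that the snippet runs without the repository modules, and the loop count of 200 is an arbitrary choice.

```python
import timeit

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for spline.s5D: a smooth scalar function of a 5-vector.
    return jnp.prod(jnp.sin(x))


x = jnp.array([0.7, 1.0, 1.5, 2.0, 2.5])
f_jitted = jax.jit(f)
df_jitted = jax.jit(jax.grad(f))

# Warm-up: the first call of a jitted function includes tracing and compilation.
f_jitted(x).block_until_ready()
df_jitted(x).block_until_ready()

# block_until_ready() waits for the asynchronously dispatched result,
# so the timer measures the computation rather than only the dispatch.
for name, fun in [("f", f), ("f_jitted", f_jitted), ("df_jitted", df_jitted)]:
    seconds = timeit.timeit(lambda fun=fun: fun(x).block_until_ready(), number=200) / 200
    print(f"{name}: {seconds * 1e6:.1f} microseconds per call")
```

The same pattern applies to `spline.s5D` and its jitted variants; without the warm-up, a single timed call of a jitted function would mostly measure compilation.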
To run this example, open `caller.ipynb` in Jupyter Notebook or execute the `caller.py` script.
The `./jupyter_notebooks` subfolder contains `.ipynb` files for the individual dimensional cases, which may help users understand or customize the code.
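When customizing the example, the N nested loops that build the synthetic `y_data` can be replaced by NumPy broadcasting, which works for any N. A minimal sketch, assuming the same `a`, `b`, `n` lists and the product-of-sines test function used in the workflow above; the helper name `synthetic_y_data` is hypothetical.

```python
import numpy as np


def synthetic_y_data(a, b, n):
    """Hypothetical helper: return (x_grid, y_data) with y = prod_j sin(x_j)
    on the N-dimensional Cartesian grid, without explicit nested loops."""
    x_grid = tuple(np.linspace(a[j], b[j], n[j] + 1) for j in range(len(a)))
    # indexing="ij" keeps axis j of each mesh array aligned with x_grid[j],
    # matching y_data[q1, q2, ..., qN] in the loop version.
    mesh = np.meshgrid(*x_grid, indexing="ij")
    y_data = np.prod([np.sin(m) for m in mesh], axis=0)
    return x_grid, y_data


a = [0, 0, 0, 0, 0]
b = [1, 2, 3, 4, 5]
n = [10, 10, 10, 10, 10]
x_grid, y_data = synthetic_y_data(a, b, n)
print(y_data.shape)  # (11, 11, 11, 11, 11)
```

For real observations, `y_data` would instead be loaded or computed on the same grid, with axis j of the array corresponding to `x_grid[j]`.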
- Maths of multidimensional natural-cubic spline interpolation: Habermann and Kindermann 2007, Multidimensional Spline Interpolation: Theory and Applications, DOI: 10.1007/s10614-007-9092-4.
- Google/JAX reference documentation: https://jax.readthedocs.io/en/latest/
- An introduction to Google/JAX for scientists (in Japanese): https://github.com/HajimeKawahara/playjax
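For orientation, the tensor-product representation of Habermann and Kindermann (2007) can be summarized as follows for an equidistant grid on each axis. This is a summary of the cited paper; that the modules follow this exact basis choice (so that the coefficient array would have n_j + 3 entries along axis j) is an assumption, and the source code is authoritative.

$$
s(x_1,\dots,x_N)=\sum_{i_1=1}^{n_1+3}\cdots\sum_{i_N=1}^{n_N+3} c_{i_1\cdots i_N}\prod_{j=1}^{N} u\!\left(\frac{x_j-a_j}{h_j}+2-i_j\right),
\qquad h_j=\frac{b_j-a_j}{n_j},
$$

$$
u(t)=\begin{cases}
4-6|t|^2+3|t|^3, & |t|\le 1,\\
(2-|t|)^3, & 1<|t|\le 2,\\
0, & |t|>2.
\end{cases}
$$

In that formulation, the coefficients are fixed by the interpolation conditions at the (n_1+1)...(n_N+1) grid points together with the natural end conditions (vanishing second derivative at a_j and b_j along each axis).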
Distributed under the MIT License. See `LICENSE.txt` for more information.
Nobuhiro Moteki
Project Link: https://github.com/nmoteki/ndimsplinejax.git
This code-development project was conceived and carried out as part of N. Moteki's research on atmospheric chemical composition at the NOAA Earth System Science Laboratory, supported by JSPS KAKENHI grant 19KK0289.