JAX integration
This notebook introduces the use of JAX with tlm_adjoint.
JAX can be used, for example, with the Firedrake backend as an escape hatch, allowing custom operations to be implemented using lower-level code. However, it can also be used independently, without a finite element code generator. Here JAX is used with tlm_adjoint to implement and differentiate a finite difference code.
Forward problem
We consider the diffusion equation in the unit square domain, subject to homogeneous Dirichlet boundary conditions,

$$\partial_t u = \kappa \left( \partial_{xx} u + \partial_{yy} u \right) + m \qquad \text{on } \left( x, y \right) \in \left( 0, 1 \right)^2,$$

with \( u = 0 \) on the boundary, where \( \kappa \) is a constant diffusivity and \( m = m \left( x, y \right) \) is a source term.
This is discretized using a basic finite difference scheme, with second order centered finite differences in the \(x\) and \(y\) dimensions and forward Euler in the \(t\) dimension.
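Restating this scheme in index notation (introduced here for clarity), with \( u^n_{i,j} \) the approximation to the solution at time level \( n \) and grid point \( \left( i, j \right) \), \( h \) the grid spacing, and \( \Delta t \) the time step size, each step updates the interior grid points via

$$u^{n + 1}_{i,j} = u^n_{i,j} + \frac{\kappa \Delta t}{h^2} \left( u^n_{i + 1,j} + u^n_{i - 1,j} + u^n_{i,j + 1} + u^n_{i,j - 1} - 4 u^n_{i,j} \right) + \Delta t \, m_{i,j},$$

with the boundary values held at zero.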
Here we implement the discretization using JAX, for \(\kappa = 0.01\), \(t \in \left[ 0, 0.25 \right]\), a time step size of \(0.0025\), and a uniform grid with a grid spacing of \(0.02\). We consider the case where \(m \left( x, y \right) = 0\), and compute the \(L^4\) norm of the \(t = T\) solution (with the \(L^4\) norm defined via nearest neighbour interpolation of the finite difference solution, ignoring boundary points where the solution is zero).
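In this notation the quantity computed below is, restating the \( L^4 \) norm definition above,

$$J = \left( h^2 \sum_{i, j \ \mathrm{interior}} \left( u^{N_t}_{i,j} \right)^4 \right)^{1/4},$$

where the sum is over the interior grid points and \( N_t \) is the number of timesteps.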
[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax
jax.config.update("jax_enable_x64", True)
N = 50
kappa = 0.01
T = 0.25
N_t = 100
dx = 1.0 / N
dt = T / N_t
x = jax.numpy.linspace(0.0, 1.0, N + 1)
y = jax.numpy.linspace(0.0, 1.0, N + 1)
X, Y = jax.numpy.meshgrid(x, y, indexing="ij")
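# Single forward Euler step: second order centred (five-point) finite difference
# approximation of the Laplacian, applied at the interior grid points only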
def timestep(x_n, m):
x_np1 = jax.numpy.zeros_like(x_n)
x_np1 = x_np1.at[1:-1, 1:-1].set(
x_n[1:-1, 1:-1]
+ (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]
+ x_n[:-2, 1:-1]
- 4.0 * x_n[1:-1, 1:-1]
+ x_n[2:, 1:-1]
+ x_n[1:-1, :-2])
+ dt * m[1:-1, 1:-1])
return x_np1
def functional(x_n):
x_n = x_n.reshape((N + 1, N + 1))
return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25
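# Advance the solution from the initial condition over N_t timesteps and
# evaluate the functional at the final time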
def forward(x_0, m):
x_n = x_0.copy()
for _ in range(N_t):
x_n = timestep(x_n, m)
J = functional(x_n)
return x_n, J
x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)
x_0 = x_0.at[1:-1, 1:-1].set(jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2 + (Y[1:-1, 1:-1] - 0.5) ** 2) / (2.0 * (0.05 ** 2))))
m = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)
fig, ax = plt.subplots(1, 1)
p = ax.contourf(x, y, x_0.T, 32)
fig.colorbar(p)
ax.set_aspect(1.0)
ax.set_title("$x_0$")
x_n, J = forward(x_0, m)
print(f"{J=}")
fig, ax = plt.subplots(1, 1)
p = ax.contourf(x, y, x_n.T, 32)
fig.colorbar(p)
ax.set_aspect(1.0)
ax.set_title("$x_{N_t}$")
J=Array(0.11009761, dtype=float64)
[1]:
Text(0.5, 1.0, '$x_{N_t}$')
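Since the forward model above is written entirely in terms of JAX operations, its derivative could also be computed directly with JAX, e.g. using jax.grad. The following sketch is included for reference only; the remainder of this notebook instead records the forward with tlm_adjoint, enabling e.g. step-based checkpointing schedules.

# A sketch, for reference only: differentiate the pure-JAX forward directly.
# forward returns (x_n, J), so the functional value is selected before
# differentiating with respect to the source term m.
dJ_dm = jax.grad(lambda m: forward(x_0, m)[1])(m)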
Adding tlm_adjoint
The tlm_adjoint Vector class wraps ndim 1 JAX arrays. The following uses the tlm_adjoint call_jax function to record JAX operations on the internal tlm_adjoint manager. Note the use of new_block, indicating steps, and enabling use of a step-based checkpointing schedule.
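As a minimal sketch of this pattern, using only constructs which appear in the cell below: a Vector can be constructed from a length or from an existing one-dimensional array, and call_jax applies a JAX-traceable function to one or more input variables, storing the result in an output variable.

from tlm_adjoint import Vector, call_jax
import jax

u = Vector(3)                      # a zero vector of length 3
v = Vector(jax.numpy.ones(3))      # wraps an existing one-dimensional array
call_jax(u, v, lambda v: 2.0 * v)  # computes u = 2 v, recorded if the manager is active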
[2]:
from tlm_adjoint import *
import jax
reset_manager()
N = 50
kappa = 0.01
T = 0.25
N_t = 100
dx = 1.0 / N
dt = T / N_t
x = jax.numpy.linspace(0.0, 1.0, N + 1)
y = jax.numpy.linspace(0.0, 1.0, N + 1)
X, Y = jax.numpy.meshgrid(x, y, indexing="ij")
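# As before, but operating on flattened arrays, since the tlm_adjoint Vector
# class stores one-dimensional arrays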
def timestep(x_n, m):
x_n = x_n.reshape((N + 1, N + 1))
m = m.reshape((N + 1, N + 1))
x_np1 = jax.numpy.zeros_like(x_n)
x_np1 = x_np1.at[1:-1, 1:-1].set(
x_n[1:-1, 1:-1]
+ (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]
+ x_n[:-2, 1:-1]
- 4.0 * x_n[1:-1, 1:-1]
+ x_n[2:, 1:-1]
+ x_n[1:-1, :-2])
+ dt * m[1:-1, 1:-1])
return x_np1.flatten()
def functional(x_n):
x_n = x_n.reshape((N + 1, N + 1))
return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25
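# The forward now stores the solution in tlm_adjoint Vector objects, with each
# timestep applied (and recorded) via call_jax, and with new_block marking the
# end of each step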
def forward(x_0, m):
x_n = Vector((N + 1) ** 2)
x_np1 = Vector((N + 1) ** 2)
x_n.assign(x_0)
for n_t in range(N_t):
call_jax(x_np1, (x_n, m), timestep)
x_n.assign(x_np1)
if n_t < N_t - 1:
new_block()
J = new_jax_float()
call_jax(J, x_n, functional)
return J
x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)
x_0 = x_0.at[1:-1, 1:-1].set(jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2 + (Y[1:-1, 1:-1] - 0.5) ** 2) / (2.0 * (0.05 ** 2))))
x_0 = Vector(x_0.flatten())
m = Vector((N + 1) ** 2)
start_manager()
J = forward(x_0, m)
stop_manager()
[2]:
(True, True)
We can now verify first order tangent-linear and adjoint calculations, and a second order reverse-over-forward calculation, using Taylor remainder convergence tests. Here we consider derivatives with respect to \(m\).
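These tests rely on the standard Taylor remainder behaviour: for a perturbation direction \( \zeta \) and perturbation amplitude \( \varepsilon \),

$$\left| J \left( m + \varepsilon \zeta \right) - J \left( m \right) \right| = O \left( \varepsilon \right),$$

while, if the computed derivative is consistent,

$$\left| J \left( m + \varepsilon \zeta \right) - J \left( m \right) - \varepsilon \frac{\mathrm{d} J}{\mathrm{d} m} \zeta \right| = O \left( \varepsilon^2 \right).$$

Observed convergence orders of approximately one without the derivative information, and approximately two with it, therefore indicate consistent derivatives, with the same pattern applying to the second order reverse-over-forward test.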
[3]:
min_order = taylor_test_tlm(lambda m: forward(x_0, m), m, tlm_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=2)
assert min_order > 1.99
Error norms, no tangent-linear = [5.21231850e-04 2.60487525e-04 1.30211754e-04 6.50978855e-05
3.25469460e-05]
Orders, no tangent-linear = [1.00071096 1.00035459 1.0001771 1.00008851]
Error norms, with tangent-linear = [5.12649194e-07 1.27924571e-07 3.19543025e-08 7.98540265e-09
1.99596553e-09]
Orders, with tangent-linear = [2.00267853 2.00121129 2.00057314 2.00027835]
Error norms, no adjoint = [5.71007638e-04 2.85359931e-04 1.42644089e-04 7.13130862e-05
3.56543048e-05]
Orders, no adjoint = [1.00072727 1.0003628 1.00018122 1.00009057]
Error norms, with adjoint = [5.74561522e-07 1.43392851e-07 3.58202146e-08 8.95174055e-09
2.23753281e-09]
Orders, with adjoint = [2.00248828 2.00112722 2.00053385 2.00025938]
Error norms, no adjoint = [9.62227787e-05 4.79560268e-05 2.39441218e-05 1.19642178e-05
5.98022754e-06]
Orders, no adjoint = [1.00466633 1.00204061 1.00094545 1.00045379]
Error norms, with adjoint = [5.68388145e-07 1.28831533e-07 3.05241525e-08 7.41895936e-09
1.82812874e-09]
Orders, with adjoint = [2.14139071 2.07746261 2.04066249 2.02084917]
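With the forward recorded above, a first order adjoint calculation can also be performed directly from the record. The following is a sketch using the tlm_adjoint compute_gradient function, with J and m as defined above.

# A sketch: derivative of J with respect to m, computed by the adjoint method
# from the record created between start_manager() and stop_manager()
dJ_dm = compute_gradient(J, m)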