JAX integration

This notebook introduces the use of JAX with tlm_adjoint.

JAX can be used, e.g. with the Firedrake backend, as an escape hatch, allowing custom operations to be implemented using lower-level code. However, it can also be used independently, without a finite element code generator. Here JAX is used with tlm_adjoint to implement and differentiate a finite difference code.

Forward problem

We consider the diffusion equation in the unit square domain, subject to homogeneous Dirichlet boundary conditions,

\[\partial_t u = \kappa \left( \partial_{xx} + \partial_{yy} \right) u + m \left( x, y \right) \qquad \text{on} ~ \left( x, y \right) \in \left( 0, 1 \right)^2,\]
\[u = 0 \qquad \text{on} ~ \partial \Omega.\]

This is discretized using a basic finite difference scheme, with second order centered finite differencing in the \(x\) and \(y\) dimensions and forward Euler integration in time.
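Writing \(u^n_{i,j}\) for the approximation to \(u\) at grid point \(\left( i \Delta x, j \Delta x \right)\) and time level \(n\), the update applied at each interior grid point (as implemented in the cell below) is

\[u^{n + 1}_{i,j} = u^n_{i,j} + \frac{\kappa \Delta t}{\Delta x^2} \left( u^n_{i + 1,j} + u^n_{i - 1,j} - 4 u^n_{i,j} + u^n_{i,j + 1} + u^n_{i,j - 1} \right) + \Delta t \, m_{i,j},\]

with \(u^{n + 1}_{i,j} = 0\) at boundary points.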

Here we implement the discretization using JAX, for \(\kappa = 0.01\), \(t \in \left[ 0, T \right]\) with \(T = 0.25\), a time step size of \(0.0025\), and a uniform grid with a grid spacing of \(0.02\). We consider the case where \(m \left( x, y \right) = 0\), with a Gaussian initial condition centered at \(\left( 0.5, 0.5 \right)\), and compute the \(L^4\) norm of the \(t = T\) solution (with the \(L^4\) norm defined via nearest neighbour interpolation of the finite difference solution, ignoring boundary points where the solution is zero).
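In terms of the final discrete solution \(u^{N_t}\) this functional is

\[J = \left( \Delta x^2 \sum_{i,j~\text{interior}} \left( u^{N_t}_{i,j} \right)^4 \right)^{1/4}.\]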

[1]:
%matplotlib inline

import matplotlib.pyplot as plt
import jax

jax.config.update("jax_enable_x64", True)

N = 50
kappa = 0.01
T = 0.25
N_t = 100

dx = 1.0 / N
dt = T / N_t


x = jax.numpy.linspace(0.0, 1.0, N + 1)
y = jax.numpy.linspace(0.0, 1.0, N + 1)
X, Y = jax.numpy.meshgrid(x, y, indexing="ij")


def timestep(x_n, m):
    # Advance the solution by a single timestep, using forward Euler in time
    # and a second order centered approximation to the Laplacian, with the
    # solution fixed at zero on the boundary
    x_np1 = jax.numpy.zeros_like(x_n)
    x_np1 = x_np1.at[1:-1, 1:-1].set(
        x_n[1:-1, 1:-1]
        + (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]
                                      + x_n[:-2, 1:-1]
                                      - 4.0 * x_n[1:-1, 1:-1]
                                      + x_n[2:, 1:-1]
                                      + x_n[1:-1, :-2])
        + dt * m[1:-1, 1:-1])
    return x_np1


def functional(x_n):
    # Discrete L^4 norm of the solution, ignoring boundary points
    x_n = x_n.reshape((N + 1, N + 1))
    return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25


def forward(x_0, m):
    x_n = x_0.copy()
    for _ in range(N_t):
        x_n = timestep(x_n, m)
    J = functional(x_n)
    return x_n, J


x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)
x_0 = x_0.at[1:-1, 1:-1].set(
    jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2
                    + (Y[1:-1, 1:-1] - 0.5) ** 2)
                  / (2.0 * (0.05 ** 2))))
m = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)

fig, ax = plt.subplots(1, 1)
p = ax.contourf(x, y, x_0.T, 32)
fig.colorbar(p)
ax.set_aspect(1.0)
ax.set_title("$x_0$")

x_n, J = forward(x_0, m)
print(f"{J=}")

fig, ax = plt.subplots(1, 1)
p = ax.contourf(x, y, x_n.T, 32)
fig.colorbar(p)
ax.set_aspect(1.0)
ax.set_title("$x_{N_t}$")
J=Array(0.11009761, dtype=float64)
[Figures: contour plots of the initial condition \(x_0\) and the final solution \(x_{N_t}\)]

Adding tlm_adjoint

The tlm_adjoint Vector class wraps one-dimensional (ndim 1) JAX arrays. The following uses the tlm_adjoint call_jax function to record JAX operations on the internal tlm_adjoint manager. Note the use of new_block, which indicates the end of a step and enables the use of a step-based checkpointing schedule.
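For example, the pattern for recording a single operation which doubles a vector would be (a minimal sketch, assuming the same Vector and call_jax signatures used in the cell below):

u = Vector(3)
v = Vector(3)
# Record v = 2 u on the manager; the callable receives and returns JAX arrays
call_jax(v, u, lambda x: 2.0 * x)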

[2]:
from tlm_adjoint import *

import jax

reset_manager()

N = 50
kappa = 0.01
T = 0.25
N_t = 100

dx = 1.0 / N
dt = T / N_t


x = jax.numpy.linspace(0.0, 1.0, N + 1)
y = jax.numpy.linspace(0.0, 1.0, N + 1)
X, Y = jax.numpy.meshgrid(x, y, indexing="ij")


def timestep(x_n, m):
    # The inputs arrive as one-dimensional arrays -- reshape to the grid
    # before timestepping, and flatten the result before returning
    x_n = x_n.reshape((N + 1, N + 1))
    m = m.reshape((N + 1, N + 1))
    x_np1 = jax.numpy.zeros_like(x_n)
    x_np1 = x_np1.at[1:-1, 1:-1].set(
        x_n[1:-1, 1:-1]
        + (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]
                                      + x_n[:-2, 1:-1]
                                      - 4.0 * x_n[1:-1, 1:-1]
                                      + x_n[2:, 1:-1]
                                      + x_n[1:-1, :-2])
        + dt * m[1:-1, 1:-1])
    return x_np1.flatten()


def functional(x_n):
    x_n = x_n.reshape((N + 1, N + 1))
    return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25


def forward(x_0, m):
    # x_n and x_np1 wrap one-dimensional JAX arrays of length (N + 1) ** 2
    x_n = Vector((N + 1) ** 2)
    x_np1 = Vector((N + 1) ** 2)
    x_n.assign(x_0)
    for n_t in range(N_t):
        # Record the timestep operation on the manager
        call_jax(x_np1, (x_n, m), timestep)
        x_n.assign(x_np1)
        if n_t < N_t - 1:
            # Indicate the end of a step
            new_block()
    # Record the evaluation of the functional
    J = new_jax_float()
    call_jax(J, x_n, functional)
    return J


x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)
x_0 = x_0.at[1:-1, 1:-1].set(
    jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2
                    + (Y[1:-1, 1:-1] - 0.5) ** 2)
                  / (2.0 * (0.05 ** 2))))
x_0 = Vector(x_0.flatten())
m = Vector((N + 1) ** 2)

start_manager()
J = forward(x_0, m)
stop_manager()
[2]:
(True, True)

We can now verify first order tangent-linear and adjoint calculations, and a second order reverse-over-forward calculation, using Taylor remainder convergence tests. Here we consider derivatives with respect to \(m\).
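These tests consider perturbations \(m + \varepsilon \zeta\) for a fixed direction \(\zeta\), and check the convergence orders of the Taylor remainders

\[\left| J \left( m + \varepsilon \zeta \right) - J \left( m \right) \right| = O \left( \varepsilon \right),\]
\[\left| J \left( m + \varepsilon \zeta \right) - J \left( m \right) - \varepsilon \, \mathrm{d} J \left( m; \zeta \right) \right| = O \left( \varepsilon^2 \right),\]

where the directional derivative \(\mathrm{d} J \left( m; \zeta \right)\) is computed using a tangent-linear or adjoint calculation. The observed orders should therefore be approximately one and two respectively.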

[3]:
min_order = taylor_test_tlm(lambda m: forward(x_0, m), m, tlm_order=1)
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=1)
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=2)
assert min_order > 1.99
Error norms, no tangent-linear   = [5.54398508e-04 2.77033803e-04 1.38475637e-04 6.92275132e-05
 3.46111816e-05]
Orders,      no tangent-linear   = [1.00086136 1.00042985 1.00021474 1.00010733]
Error norms, with tangent-linear = [6.60788377e-07 1.64942869e-07 4.12066718e-08 1.02982105e-08
 2.57413140e-09]
Orders,      with tangent-linear = [2.0022219  2.00101656 2.00048428 2.00023606]
Error norms, no adjoint   = [5.46708909e-04 2.73211003e-04 1.36569730e-04 6.82759325e-05
 3.41357343e-05]
Orders,      no adjoint   = [1.0007573  1.00037783 1.00018874 1.00009433]
Error norms, with adjoint = [5.72858211e-07 1.42977754e-07 3.57175399e-08 8.92619386e-09
 2.23116036e-09]
Orders,      with adjoint = [2.00238741 2.00108607 2.00051567 2.00025093]
Error norms, no adjoint   = [1.15618066e-04 5.76187476e-05 2.87665245e-05 1.43731502e-05
 7.18412279e-06]
Orders,      no adjoint   = [1.00475664 1.00214738 1.00101463 1.00049239]
Error norms, with adjoint = [7.10528022e-07 1.64978412e-07 3.96398721e-08 9.70792438e-09
 2.40163446e-09]
Orders,      with adjoint = [2.10661429 2.05725304 2.02971753 2.01514629]
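With the forward recorded, the derivative of \(J\) with respect to \(m\) can also be computed directly using the adjoint, e.g. (a minimal sketch, assuming the tlm_adjoint compute_gradient function):

# Compute the derivative of J with respect to m using an adjoint calculation
dJ_dm = compute_gradient(J, m)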