{ "cells": [ { "cell_type": "markdown", "id": "ea6ba877-3da2-4a5b-b878-8350e1c74b08", "metadata": {}, "source": [ "# JAX integration\n", "\n", "This notebook introduces the use of [JAX](https://jax.readthedocs.io) with tlm_adjoint.\n", "\n", "JAX can be used, for example, with the Firedrake backend as an escape hatch, allowing custom operations to be implemented using lower-level code. However, it can also be used independently, without a finite element code generator. Here JAX is used with tlm_adjoint to implement and differentiate a finite difference code.\n", "\n", "## Forward problem\n", "\n", "We consider the diffusion equation in the unit square domain $\\Omega = \\left( 0, 1 \\right)^2$, subject to homogeneous Dirichlet boundary conditions,\n", "\n", "$$\\partial_t u = \\kappa \\left( \\partial_{xx} + \\partial_{yy} \\right) u + m \\left( x, y \\right) \\qquad \\text{on} ~ \\left( x, y \\right) \\in \\left( 0, 1 \\right)^2,$$\n", "$$u = 0 \\qquad \\text{on} ~ \\partial \\Omega.$$\n", "\n", "This is discretized using a basic finite difference scheme, using second-order centered finite differencing for the $x$ and $y$ dimensions and forward Euler for the $t$ dimension.\n", "\n", "Here we implement the discretization using JAX, for $\\kappa = 0.01$, $t \\in \\left[ 0, T \\right]$ with $T = 0.25$, a time step size of $0.0025$, and a uniform grid with a grid spacing of $0.02$. We consider the case where $m \\left( x, y \\right) = 0$, and compute the $L^4$ norm of the $t = T$ solution (with the $L^4$ norm defined via nearest neighbour interpolation of the finite difference solution, ignoring boundary points where the solution is zero)." 
] }, { "cell_type": "code", "execution_count": null, "id": "4d51cefb-f6e5-4178-bf9e-07469c6410ec", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "import jax\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "N = 50\n", "kappa = 0.01\n", "T = 0.25\n", "N_t = 100\n", "\n", "dx = 1.0 / N\n", "dt = T / N_t\n", "\n", "\n", "x = jax.numpy.linspace(0.0, 1.0, N + 1)\n", "y = jax.numpy.linspace(0.0, 1.0, N + 1)\n", "X, Y = jax.numpy.meshgrid(x, y, indexing=\"ij\")\n", "\n", "\n", "def timestep(x_n, m):\n", "    x_np1 = jax.numpy.zeros_like(x_n)\n", "    x_np1 = x_np1.at[1:-1, 1:-1].set(\n", "        x_n[1:-1, 1:-1]\n", "        + (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]\n", "                                      + x_n[:-2, 1:-1]\n", "                                      - 4.0 * x_n[1:-1, 1:-1]\n", "                                      + x_n[2:, 1:-1]\n", "                                      + x_n[1:-1, :-2])\n", "        + dt * m[1:-1, 1:-1])\n", "    return x_np1\n", "\n", "\n", "def functional(x_n):\n", "    x_n = x_n.reshape((N + 1, N + 1))\n", "    return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25\n", "\n", "\n", "def forward(x_0, m):\n", "    x_n = x_0.copy()\n", "    for _ in range(N_t):\n", "        x_n = timestep(x_n, m)\n", "    J = functional(x_n)\n", "    return x_n, J\n", "\n", "\n", "x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)\n", "x_0 = x_0.at[1:-1, 1:-1].set(jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2 + (Y[1:-1, 1:-1] - 0.5) ** 2) / (2.0 * (0.05 ** 2))))\n", "m = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "p = ax.contourf(x, y, x_0.T, 32)\n", "fig.colorbar(p)\n", "ax.set_aspect(1.0)\n", "ax.set_title(\"$x_0$\")\n", "\n", "x_n, J = forward(x_0, m)\n", "print(f\"{J=}\")\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "p = ax.contourf(x, y, x_n.T, 32)\n", "fig.colorbar(p)\n", "ax.set_aspect(1.0)\n", "ax.set_title(\"$x_{N_t}$\")" ] }, { "cell_type": "markdown", "id": "3ff06fb2-5469-492f-95ce-7ce741375e30", "metadata": {}, "source": [ "## Adding tlm_adjoint\n", "\n", "The tlm_adjoint `Vector` class wraps 
one-dimensional JAX arrays. The following uses the tlm_adjoint `call_jax` function to record JAX operations on the internal tlm_adjoint manager. Note the use of `new_block`, which marks the boundaries between steps and enables the use of a step-based checkpointing schedule." ] }, { "cell_type": "code", "execution_count": null, "id": "55f03da3", "metadata": {}, "outputs": [], "source": [ "from tlm_adjoint import *\n", "\n", "import jax\n", "\n", "reset_manager()\n", "\n", "N = 50\n", "kappa = 0.01\n", "T = 0.25\n", "N_t = 100\n", "\n", "dx = 1.0 / N\n", "dt = T / N_t\n", "\n", "\n", "x = jax.numpy.linspace(0.0, 1.0, N + 1)\n", "y = jax.numpy.linspace(0.0, 1.0, N + 1)\n", "X, Y = jax.numpy.meshgrid(x, y, indexing=\"ij\")\n", "\n", "\n", "def timestep(x_n, m):\n", "    x_n = x_n.reshape((N + 1, N + 1))\n", "    m = m.reshape((N + 1, N + 1))\n", "    x_np1 = jax.numpy.zeros_like(x_n)\n", "    x_np1 = x_np1.at[1:-1, 1:-1].set(\n", "        x_n[1:-1, 1:-1]\n", "        + (kappa * dt / (dx * dx)) * (x_n[1:-1, 2:]\n", "                                      + x_n[:-2, 1:-1]\n", "                                      - 4.0 * x_n[1:-1, 1:-1]\n", "                                      + x_n[2:, 1:-1]\n", "                                      + x_n[1:-1, :-2])\n", "        + dt * m[1:-1, 1:-1])\n", "    return x_np1.flatten()\n", "\n", "\n", "def functional(x_n):\n", "    x_n = x_n.reshape((N + 1, N + 1))\n", "    return (dx * dx * (x_n[1:-1, 1:-1] ** 4).sum()) ** 0.25\n", "\n", "\n", "def forward(x_0, m):\n", "    x_n = Vector((N + 1) ** 2)\n", "    x_np1 = Vector((N + 1) ** 2)\n", "    x_n.assign(x_0)\n", "    for n_t in range(N_t):\n", "        call_jax(x_np1, (x_n, m), timestep)\n", "        x_n.assign(x_np1)\n", "        if n_t < N_t - 1:\n", "            new_block()\n", "    J = new_jax_float()\n", "    call_jax(J, x_n, functional)\n", "    return J\n", "\n", "\n", "x_0 = jax.numpy.zeros((N + 1, N + 1), dtype=jax.numpy.double)\n", "x_0 = x_0.at[1:-1, 1:-1].set(jax.numpy.exp(-((X[1:-1, 1:-1] - 0.5) ** 2 + (Y[1:-1, 1:-1] - 0.5) ** 2) / (2.0 * (0.05 ** 2))))\n", "x_0 = Vector(x_0.flatten())\n", "m = Vector((N + 1) ** 2)\n", "\n", "start_manager()\n", "J = forward(x_0, m)\n", "stop_manager()" ] }, { "cell_type": "markdown", "id": 
"5df893e2-47b5-4ba1-81c8-04542c344448", "metadata": {}, "source": [ "We can now verify first order tangent-linear and adjoint calculations, and a second order reverse-over-forward calculation, using Taylor remainder convergence tests. Here we consider derivatives with respect to $m$." ] }, { "cell_type": "code", "execution_count": null, "id": "33eb96c6-8934-4c2c-b356-7b29962842e6", "metadata": {}, "outputs": [], "source": [ "min_order = taylor_test_tlm(lambda m: forward(x_0, m), m, tlm_order=1)\n", "assert min_order > 1.99\n", "\n", "min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=1)\n", "assert min_order > 1.99\n", "\n", "min_order = taylor_test_tlm_adjoint(lambda m: forward(x_0, m), m, adjoint_order=2)\n", "assert min_order > 1.99" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }