# Time-dependent example

This notebook describes the calculation of derivative information for a time-dependent problem using tlm_adjoint with the [Firedrake](https://firedrakeproject.org/) backend. Overheads associated with building the records of calculations are discussed, and a checkpointing schedule is applied.

The binomial checkpointing schedule is based on the method described in:

- Andreas Griewank and Andrea Walther, 'Algorithm 799: revolve: an implementation of checkpointing for the reverse or adjoint mode of computational differentiation', ACM Transactions on Mathematical Software, 26(1), pp. 19–45, 2000, doi: 10.1145/347837.347846

## Forward problem

We consider the solution of a linear time-dependent partial differential equation, followed by the calculation of the square of the $L^2$-norm of the final time solution. We assume real spaces and a real build of Firedrake throughout.

Specifically we consider the advection-diffusion equation in two dimensions, in the form

$$
 \partial_t u + \partial_x \psi \partial_y u - \partial_y \psi \partial_x u = \kappa \left( \partial_{xx} + \partial_{yy} \right) u,
$$

where $\psi$ vanishes on the domain boundary, and subject to zero flux boundary conditions. We consider the spatial domain $\left( x, y \right) \in \left( 0, 1 \right)^2$ and temporal domain $t \in \left[ 0, 0.1 \right]$, with $\psi \left( x, y \right) = -\sin \left( \pi x \right) \sin \left( \pi y \right)$ and $\kappa = 0.01$, and an initial condition $u \left( x, y, t=0 \right) = \exp \left[ -50 \left( \left( x - 0.75 \right)^2 + \left( y - 0.5 \right)^2 \right) \right]$.

The problem is discretized using $P_1$ continuous finite elements to represent both the solution $u$ at each time level and the stream function $\psi$. The problem is discretized in time using the implicit trapezoidal rule.
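
Written out for a single timestep, the scheme can be sketched as follows. The notation is introduced here for illustration: $u^n$ is the discrete solution at time level $n$, $\Delta t$ is the timestep size, $\Omega = \left( 0, 1 \right)^2$, and $v$ is a $P_1$ test function. We seek $u^{n + 1}$ such that, for all $v$,

$$
 \int_\Omega \left( u^{n + 1} - u^n \right) v + \Delta t \int_\Omega \left( \partial_x \psi \partial_y u^{n + 1/2} - \partial_y \psi \partial_x u^{n + 1/2} \right) v + \Delta t \int_\Omega \kappa \nabla u^{n + 1/2} \cdot \nabla v = 0,
$$

where $u^{n + 1/2} = \frac{1}{2} \left( u^n + u^{n + 1} \right)$. The diffusion term has been integrated by parts, and the resulting boundary term vanishes under the zero flux boundary conditions.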

A simple implementation in Firedrake takes the form:

In [None]:
%matplotlib inline

from firedrake import *
from firedrake.pyplot import tricontourf

import matplotlib.pyplot as plt
import numpy as np

T = 0.1
N = 100
dt = Constant(T / N)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi")
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01)

u_0 = Function(space, name="u_0")
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))

u_n = Function(space, name="u_n")
u_np1 = Function(space, name="u_np1")

u_h = 0.5 * (u_n + trial)
F = (inner(trial - u_n, test) * dx
     + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
     + dt * inner(kappa * grad(u_h), grad(test)) * dx)
lhs, rhs = system(F)

problem = LinearVariationalProblem(
    lhs, rhs, u_np1,
    constant_jacobian=True)
solver = LinearVariationalSolver(
    problem, solver_parameters={"ksp_type": "preonly",
                                "pc_type": "lu"})

u_n.assign(u_0)
for n in range(N):
    solver.solve()
    u_n.assign(u_np1)

J = assemble(inner(u_n, u_n) * dx)


def plot_output(u, title):
    r = (u.dat.data_ro.min(), u.dat.data_ro.max())
    eps = (r[1] - r[0]) * 1.0e-12
    p = tricontourf(u, np.linspace(r[0] - eps, r[1] + eps, 32))
    plt.gca().set_title(title)
    plt.colorbar(p)
    plt.gca().set_aspect(1.0)


plot_output(u_0, title="$u_0$")
plot_output(u_n, title="$u_n$")

## Adding tlm_adjoint

We first modify the code so that tlm_adjoint processes the calculations:

In [None]:
from firedrake import *
from tlm_adjoint.firedrake import *

reset_manager("memory", {})

T = 0.1
N = 100
dt = Constant(T / N)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi")
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01)

u_0 = Function(space, name="u_0")
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))


def forward(u_0, psi):
    u_n = Function(space, name="u_n")
    u_np1 = Function(space, name="u_np1")

    u_h = 0.5 * (u_n + trial)
    F = (inner(trial - u_n, test) * dx
         + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
         + dt * inner(kappa * grad(u_h), grad(test)) * dx)
    lhs, rhs = system(F)

    problem = LinearVariationalProblem(
        lhs, rhs, u_np1,
        constant_jacobian=True)
    solver = LinearVariationalSolver(
        problem, solver_parameters={"ksp_type": "preonly",
                                    "pc_type": "lu"})

    u_n.assign(u_0)
    for n in range(N):
        solver.solve()
        u_n.assign(u_np1)

    J = Functional(name="J")
    J.assign(inner(u_n, u_n) * dx)
    return J


start_manager()
J = forward(u_0, psi)
stop_manager()

Later we will configure a checkpointing schedule. Resetting the manager resets the record of forward equations but does not reset the checkpointing configuration, and so in this example whenever we reset the manager we also return it to the default checkpointing configuration with `reset_manager("memory", {})`.

## Computing derivatives using an adjoint

The `compute_gradient` function can be used to compute derivatives using the adjoint method. Here we compute the derivative of the square of the $L^2$-norm of the final timestep solution, considered a function of the control defined by the initial condition `u_0` and stream function `psi`, with respect to this control:

In [None]:
dJ_du_0, dJ_dpsi = compute_gradient(J, (u_0, psi))

As a simple check of the result, note that the solution to the (discretized) partial differential equation is unchanged by the addition of a constant to the stream function. Hence we expect the directional derivative with respect to the stream function, with direction equal to the unity valued function, to be zero. This is indeed found to be the case (except for roundoff errors):

In [None]:
one = Function(space, name="one")
one.interpolate(Constant(1.0))

dJ_dpsi_one = var_inner(one, dJ_dpsi)

print(f"{dJ_dpsi_one=}")

assert abs(dJ_dpsi_one) < 1.0e-17
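
To illustrate what an adjoint calculation of this kind involves, the following is a minimal NumPy sketch for a small dense analogue of the timestepping problem. It does not use tlm_adjoint, and the matrices `M` and `A`, the problem sizes, and the function names are assumptions made purely for illustration. The forward model repeatedly applies one implicit trapezoidal step. The adjoint applies the transposed steps in reverse order, starting from the derivative of the functional with respect to the final state.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dof, n_steps, dt = 5, 20, 0.01

# Illustrative stand-ins (not taken from the Firedrake discretization):
# a symmetric positive definite "mass" matrix and a generic operator
M = np.eye(n_dof) + 0.1 * np.ones((n_dof, n_dof))
A = rng.standard_normal((n_dof, n_dof))

# Implicit trapezoidal rule: lhs_mat u_{n+1} = rhs_mat u_n
lhs_mat = M + 0.5 * dt * A
rhs_mat = M - 0.5 * dt * A


def forward(u_0):
    """Return the final state u_N and the functional J = u_N^T M u_N."""
    u = u_0.copy()
    for _ in range(n_steps):
        u = np.linalg.solve(lhs_mat, rhs_mat @ u)
    return u, u @ M @ u


def gradient(u_0):
    """Return dJ/du_0 by applying the transposed steps in reverse order."""
    u_N, _ = forward(u_0)
    lam = 2.0 * M @ u_N  # dJ/du_N, using the symmetry of M
    for _ in range(n_steps):
        lam = rhs_mat.T @ np.linalg.solve(lhs_mat.T, lam)
    return lam


u_0 = rng.standard_normal(n_dof)
direction = rng.standard_normal(n_dof)

# J is quadratic in u_0, so a central difference is exact up to roundoff
eps = 1.0e-6
dJ_fd = (forward(u_0 + eps * direction)[1]
         - forward(u_0 - eps * direction)[1]) / (2.0 * eps)
dJ_adj = gradient(u_0) @ direction
print(f"{dJ_fd=} {dJ_adj=}")
```

The two printed directional derivatives should agree to within roundoff. Because this toy problem is linear, the adjoint loop needs no stored forward states; storage becomes necessary when the adjoint of a step depends on the forward solution.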

## Computing Hessian information using an adjoint of a tangent-linear

We next compute a Hessian action. Although the following calculation does work, it is inefficient – you may wish to skip forward to the optimized calculations.

Here we compute a 'mixed' Hessian action, by defining a directional derivative with respect to the stream function, and then differentiating this with respect to the initial condition:

In [None]:
from firedrake import *
from tlm_adjoint.firedrake import *

reset_manager("memory", {})

T = 0.1
N = 100
dt = Constant(T / N)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi")
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01)

u_0 = Function(space, name="u_0")
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))


def forward(u_0, psi):
    u_n = Function(space, name="u_n")
    u_np1 = Function(space, name="u_np1")

    u_h = 0.5 * (u_n + trial)
    F = (inner(trial - u_n, test) * dx
         + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
         + dt * inner(kappa * grad(u_h), grad(test)) * dx)
    lhs, rhs = system(F)

    problem = LinearVariationalProblem(
        lhs, rhs, u_np1,
        constant_jacobian=True)
    solver = LinearVariationalSolver(
        problem, solver_parameters={"ksp_type": "preonly",
                                    "pc_type": "lu"})

    u_n.assign(u_0)
    for n in range(N):
        solver.solve()
        u_n.assign(u_np1)

    J = Functional(name="J")
    J.assign(inner(u_n, u_n) * dx)
    return J


zeta = Function(space, name="zeta")
zeta.assign(psi)
configure_tlm((psi, zeta))

start_manager()
J = forward(u_0, psi)
stop_manager()

dJ_dpsi_zeta = var_tlm(J, (psi, zeta))

d2J_dpsi_zeta_du_0 = compute_gradient(dJ_dpsi_zeta, u_0)
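
The structure of this calculation can be illustrated with a small NumPy sketch, independent of tlm_adjoint. The model below is a hypothetical explicit timestepper whose step matrix depends on a scalar parameter `p`, standing in for the stream function; all names and values are assumptions made for illustration. A tangent-linear model propagates the derivative with respect to `p` in the direction `zeta`. The directional derivative of the functional is formed from the final forward and tangent-linear states, and an adjoint of the combined forward and tangent-linear system then differentiates it with respect to the initial condition.

```python
import numpy as np

rng = np.random.default_rng(1)
n_dof, n_steps, dt = 4, 10, 0.05
K = rng.standard_normal((n_dof, n_dof))  # illustrative operator
p, zeta = 0.7, 1.0  # parameter value and tangent-linear direction
B = np.eye(n_dof) + dt * p * K  # step matrix: u_{n+1} = B u_n


def dJ_dp_zeta(u_0):
    """Directional derivative of J = u_N . u_N with respect to p."""
    u, tau = u_0.copy(), np.zeros(n_dof)
    for _ in range(n_steps):
        # The tangent-linear step uses the forward state before its update
        tau = B @ tau + dt * zeta * (K @ u)
        u = B @ u
    return u, tau, 2.0 * (u @ tau)


def mixed_hessian_action(u_0):
    """Derivative of dJ_dp_zeta with respect to u_0, via an adjoint."""
    u_N, tau_N, _ = dJ_dp_zeta(u_0)
    lam_u, lam_tau = 2.0 * tau_N, 2.0 * u_N
    for _ in range(n_steps):
        # Transpose of the combined forward + tangent-linear step
        lam_u = B.T @ lam_u + dt * zeta * (K.T @ lam_tau)
        lam_tau = B.T @ lam_tau
    return lam_u


u_0 = rng.standard_normal(n_dof)
direction = rng.standard_normal(n_dof)

# dJ_dp_zeta is quadratic in u_0, so a central difference is exact up to
# roundoff
eps = 1.0e-6
fd = (dJ_dp_zeta(u_0 + eps * direction)[2]
      - dJ_dp_zeta(u_0 - eps * direction)[2]) / (2.0 * eps)
adj = mixed_hessian_action(u_0) @ direction
print(f"{fd=} {adj=}")
```

The adjoint here carries two variables, one for the forward state and one for the tangent-linear state, which mirrors the extra processing associated with the tangent-linear in a second order calculation.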

## Optimization

In the above we have successfully built a record of calculations, and used this to compute derivative information. However there are two issues:

1. Building the record has a noticeable cost – the forward calculation has slowed down. In the second order calculation, overheads associated with the tangent-linear lead to substantial additional costs.
2. tlm_adjoint records the solution of the partial differential equation on all time levels. The memory usage here is manageable. However memory limits will be exceeded for larger problems with more fields, spatial degrees of freedom, or timesteps.

Let's fix these issues in order.

### Optimizing the annotation

In the above code tlm_adjoint builds a new record for each finite element variational problem it encounters. Even though only one `LinearVariationalSolver` is instantiated, an `EquationSolver` record is instantiated on each call to the `solve` method. Building the record is sufficiently expensive that the forward calculation noticeably slows down, and this also leads to significant extra processing in the derivative calculations.

Instead we can instantiate an `EquationSolver` directly, and reuse it. However if we do only that then the code will still be inefficient. A single `EquationSolver` will be used, but new linear solver data will be constructed each time its `solve` method is called. We need to also apply an optimization analogous to the `constant_jacobian=True` argument supplied to `LinearVariationalProblem`.

A simple fix is to add `cache_jacobian=True` when instantiating the `EquationSolver`:

```
eq = EquationSolver(
    lhs == rhs, u_np1,
    solver_parameters={"ksp_type": "preonly",
                       "pc_type": "lu"},
    cache_jacobian=True)
```

This works, but we can instead let tlm_adjoint detect that linear solver data can be cached. We can do that by adding `static=True` when instantiating variables whose value is unchanged throughout the forward calculation:

In [None]:
from firedrake import *
from tlm_adjoint.firedrake import *

reset_manager("memory", {})
clear_caches()

T = 0.1
N = 100
dt = Constant(T / N, static=True)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi", static=True)
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01, static=True)

u_0 = Function(space, name="u_0", static=True)
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))


def forward(u_0, psi):
    u_n = Function(space, name="u_n")
    u_np1 = Function(space, name="u_np1")

    u_h = 0.5 * (u_n + trial)
    F = (inner(trial - u_n, test) * dx
         + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
         + dt * inner(kappa * grad(u_h), grad(test)) * dx)
    lhs, rhs = system(F)

    eq = EquationSolver(
        lhs == rhs, u_np1,
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "lu"})

    u_n.assign(u_0)
    for n in range(N):
        eq.solve()
        u_n.assign(u_np1)

    J = Functional(name="J")
    J.assign(inner(u_n, u_n) * dx)
    return J


start_manager()
J = forward(u_0, psi)
stop_manager()

If we now query the relevant tlm_adjoint caches:

In [None]:
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")

assert len(assembly_cache()) == 2
assert len(linear_solver_cache()) == 1

we find that linear solver data associated with a single matrix has been cached. We also find that two assembled objects have been cached – it turns out that there are two cached matrices. As well as caching the matrix associated with the left-hand-side of the discrete problem, a matrix associated with the *right-hand-side* has been assembled and cached. Assembly of the right-hand-side has been converted into a matrix multiply. If we wished we could disable right-hand-side optimizations by adding `cache_rhs_assembly=False`:

```
eq = EquationSolver(
    lhs == rhs, u_np1,
    solver_parameters={"ksp_type": "preonly",
                       "pc_type": "lu"},
    cache_rhs_assembly=False)
```
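
In matrix terms the cached objects can be sketched as follows, using notation introduced here for illustration: $U^n$ is the vector of degrees of freedom at time level $n$, $\Delta t$ is the timestep size, $M$ is the mass matrix, and $A$ is the matrix arising from the advection and diffusion terms. Each timestep then solves

$$
 \left( M + \frac{\Delta t}{2} A \right) U^{n + 1} = \left( M - \frac{\Delta t}{2} A \right) U^n.
$$

Since the timestep size, $\psi$, and $\kappa$ are static, neither matrix changes between timesteps. The two cached assembled objects correspond to the matrices on the left-hand-side and right-hand-side, and the cached linear solver data is associated with the left-hand-side matrix.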

### Using a checkpointing schedule

To address the storage issue we enable checkpointing. Here we enable binomial checkpointing with storage of a maximum of $10$ forward restart checkpoints in memory:

In [None]:
from firedrake import *
from tlm_adjoint.firedrake import *

import logging

logger = logging.getLogger("tlm_adjoint")
logger.setLevel(logging.DEBUG)
root_logger = logging.getLogger()
if len(logger.handlers) == 1:
    if len(root_logger.handlers) == 1:
        root_logger.handlers.pop()
    root_logger.addHandler(logger.handlers.pop())

reset_manager("memory", {})
clear_caches()

T = 0.1
N = 100
dt = Constant(T / N, static=True)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi", static=True)
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01, static=True)

u_0 = Function(space, name="u_0", static=True)
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))


def forward(u_0, psi):
    u_n = Function(space, name="u_n")
    u_np1 = Function(space, name="u_np1")

    u_h = 0.5 * (u_n + trial)
    F = (inner(trial - u_n, test) * dx
         + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
         + dt * inner(kappa * grad(u_h), grad(test)) * dx)
    lhs, rhs = system(F)

    eq = EquationSolver(
        lhs == rhs, u_np1,
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "lu"})

    u_n.assign(u_0)
    for n in range(N):
        eq.solve()
        u_n.assign(u_np1)
        if n < N - 1:
            new_block()

    J = Functional(name="J")
    J.assign(inner(u_n, u_n) * dx)
    return J


configure_checkpointing("multistage", {"snaps_in_ram": 10, "blocks": N})
start_manager()
J = forward(u_0, psi)
stop_manager()

The key changes here are:

- Configuration of a checkpointing schedule using `configure_checkpointing`. Here binomial checkpointing is applied, with a maximum of $10$ forward restart checkpoints stored in memory, indicated using the `"snaps_in_ram"` parameter. The total number of steps is indicated using the `"blocks"` parameter.
- The indication of the steps using `new_block()`.

Extra logging output is also enabled so that we can see the details of the checkpointing schedule.
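
For a rough sense of what such a schedule can achieve, the revolve paper cited above gives a classical bound: with $s$ checkpoints and a repetition number $r$ (broadly, a limit on how many times any one step is advanced forward), up to $\binom{s + r}{s}$ steps can be reversed. The following pure Python sketch evaluates this bound for the values used here. The function names are hypothetical, and the result need not match the details reported in the logging output, since the `"multistage"` schedule in tlm_adjoint may count steps and checkpoints differently.

```python
from math import comb


def max_steps(snaps, reps):
    """Classical binomial bound on the number of reversible steps."""
    return comb(snaps + reps, snaps)


def min_repetitions(n_steps, snaps):
    """Smallest repetition number for which the bound covers n_steps."""
    reps = 0
    while max_steps(snaps, reps) < n_steps:
        reps += 1
    return reps


snaps, n_steps = 10, 100
reps = min_repetitions(n_steps, snaps)
print(f"{reps=} {max_steps(snaps, reps - 1)=} {max_steps(snaps, reps)=}")
```

Under this bound, $10$ checkpoints with a repetition number of $2$ cover at most $66$ steps, so $100$ steps need a repetition number of $3$, for which the bound is $286$.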

### Computing derivatives

We are now ready to compute derivatives. However a key restriction is that, with this checkpointing schedule, we can perform the adjoint calculation only *once* per forward calculation. We cannot call `compute_gradient` a second time without first rerunning the entire forward calculation.

In the following we compute both first and second derivative information using a single adjoint calculation:

In [None]:
from firedrake import *
from tlm_adjoint.firedrake import *

import logging

logger = logging.getLogger("tlm_adjoint")
logger.setLevel(logging.DEBUG)
root_logger = logging.getLogger()
if len(logger.handlers) == 1:
    if len(root_logger.handlers) == 1:
        root_logger.handlers.pop()
    root_logger.addHandler(logger.handlers.pop())

reset_manager("memory", {})
clear_caches()

T = 0.1
N = 100
dt = Constant(T / N, static=True)

mesh = UnitSquareMesh(128, 128)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

psi = Function(space, name="psi", static=True)
psi.interpolate(-sin(pi * X[0]) * sin(pi * X[1]))

kappa = Constant(0.01, static=True)

u_0 = Function(space, name="u_0", static=True)
u_0.interpolate(exp(-50.0 * ((X[0] - 0.75) ** 2 + (X[1] - 0.5) ** 2)))


def forward(u_0, psi):
    u_n = Function(space, name="u_n")
    u_np1 = Function(space, name="u_np1")

    u_h = 0.5 * (u_n + trial)
    F = (inner(trial - u_n, test) * dx
         + dt * inner(psi.dx(0) * u_h.dx(1) - psi.dx(1) * u_h.dx(0), test) * dx
         + dt * inner(kappa * grad(u_h), grad(test)) * dx)
    lhs, rhs = system(F)

    eq = EquationSolver(
        lhs == rhs, u_np1,
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "lu"})

    u_n.assign(u_0)
    for n in range(N):
        eq.solve()
        u_n.assign(u_np1)
        if n < N - 1:
            new_block()

    J = Functional(name="J")
    J.assign(inner(u_n, u_n) * dx)
    return J


zeta_u_0 = ZeroFunction(space, name="zeta_u_0")
zeta_psi = Function(space, name="zeta_psi", static=True)
zeta_psi.assign(psi)
configure_tlm(((u_0, psi), (zeta_u_0, zeta_psi)))

configure_checkpointing("multistage", {"snaps_in_ram": 10, "blocks": N})
start_manager()
J = forward(u_0, psi)
stop_manager()

dJ_dpsi_zeta = var_tlm(J, ((u_0, psi), (zeta_u_0, zeta_psi)))

dJ_du_0, dJ_dpsi, d2J_dpsi_zeta_du_0 = compute_gradient(
 dJ_dpsi_zeta, (zeta_u_0, zeta_psi, u_0))

The derivative calculation now alternates between forward + tangent-linear calculations, and adjoint calculations.
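
This alternation can be illustrated with a deliberately simplified pure Python sketch. It uses a hypothetical scalar nonlinear recurrence, for which the adjoint needs the forward state at every step, and a uniform checkpoint stride rather than a binomial schedule. None of the names correspond to tlm_adjoint functions. The forward sweep keeps only every `stride`-th state, and the reverse sweep recomputes each segment from its checkpoint before running the adjoint over it.

```python
import math

N, dt = 100, 0.01


def step(u):
    """One forward step of an illustrative nonlinear recurrence."""
    return u + dt * math.sin(u)


def step_adjoint(u, lam):
    """Adjoint of one step, linearized about the forward state u."""
    return (1.0 + dt * math.cos(u)) * lam


def gradient_store_all(u_0):
    """dJ/du_0 for J = u_N ** 2, storing every forward state."""
    states = [u_0]
    for _ in range(N):
        states.append(step(states[-1]))
    lam = 2.0 * states[N]
    for n in reversed(range(N)):
        lam = step_adjoint(states[n], lam)
    return lam


def gradient_checkpointed(u_0, stride=10):
    """As above, but storing only every stride-th forward state."""
    checkpoints = {}
    u = u_0
    for n in range(N):
        if n % stride == 0:
            checkpoints[n] = u
        u = step(u)
    lam = 2.0 * u
    for start in sorted(checkpoints, reverse=True):
        end = min(start + stride, N)
        # Forward recomputation of the states u_start, ..., u_{end - 1}
        segment = [checkpoints[start]]
        for _ in range(start, end - 1):
            segment.append(step(segment[-1]))
        # Adjoint calculation over the recomputed segment
        for u_n in reversed(segment):
            lam = step_adjoint(u_n, lam)
    return lam


g_all = gradient_store_all(0.3)
g_chk = gradient_checkpointed(0.3)
print(f"{g_all=} {g_chk=}")
```

Both approaches should give the same derivative. The checkpointed version holds at most `N / stride` checkpoints plus one recomputed segment in memory, at the cost of advancing most steps a second time.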

If we wished we could perform higher order adjoint calculations, using a binomial checkpointing schedule, by supplying a higher order tangent-linear configuration and differentiating the result.