Time-independent example

This notebook describes adjoint and Hessian calculations using tlm_adjoint with the Firedrake backend. A time-independent problem is considered and, importantly, checkpointing is not used for the adjoint calculations. This notebook further describes how variables may be flagged to facilitate caching.

The high-level algorithmic differentiation approach used by tlm_adjoint is based on the method described in:

  • P. E. Farrell, D. A. Ham, S. W. Funke, and M. E. Rognes, ‘Automated derivation of the adjoint of high-level transient finite element programs’, SIAM Journal on Scientific Computing 35(4), pp. C369–C393, 2013, doi: 10.1137/120873558

The caching of data in tlm_adjoint uses an approach based on that described in:

  • J. R. Maddison and P. E. Farrell, ‘Rapid development and adjoining of transient finite element models’, Computer Methods in Applied Mechanics and Engineering, 276, 95–121, 2014, doi: 10.1016/j.cma.2014.03.010

Forward problem

We consider the solution of a linear time-independent partial differential equation, followed by the calculation of the \(L^2\)-norm of the solution. Extra non-linearity is introduced by allowing the right-hand-side of the partial differential equation to depend non-linearly on the control. We assume real spaces and a real build of Firedrake throughout.

Specifically we consider the solution \(u \in V_0\) of

\[\forall \zeta \in V_0 \qquad \int_\Omega \nabla \zeta \cdot \nabla u = \int_\Omega \zeta m^2,\]

where \(V\) is a real \(P_1\) continuous finite element space defining functions on the domain \(\Omega = \left( 0, 1 \right)^2\), with \(m \in V\), and where \(V_0\) consists of the functions in \(V\) which have zero trace. This corresponds to a discretization of the partial differential equation

\[-\nabla^2 u = m^2 \quad \text{on } \left( x, y \right) \in \left( 0, 1 \right)^2,\]

subject to homogeneous Dirichlet boundary conditions.

A simple implementation in Firedrake, with \(m = x y\), takes the form:

[1]:
from firedrake import *

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])

u = Function(space, name="u")
solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
      DirichletBC(space, 0.0, "on_boundary"))

J_sq = assemble(inner(u, u) * dx)
J = sqrt(J_sq)

Adding tlm_adjoint

We first modify the code so that tlm_adjoint processes the calculations:

[2]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


start_manager()
J = forward(m)
stop_manager()
[2]:
(True, True)

The key changes here are:

  • To import tlm_adjoint with the Firedrake backend.

  • Controlling the ‘manager’ – an object tlm_adjoint uses to process equations.

  • Using a Functional to compute the square of the \(L^2\)-norm of the solution of the (discretized) partial differential equation. This facilitates the calculation of simple functionals, e.g. using finite element assembly.

  • Taking the square root of the square of the \(L^2\)-norm using NumPy.

Let’s display the information recorded by tlm_adjoint:

[3]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
  Block 0
    Equation 0, EquationSolver solving for u (id 3)
      Dependency 0, u (id 3), linear
      Dependency 1, m (id 0), non-linear
    Equation 1, Assembly solving for J_sq (id 9)
      Dependency 0, J_sq (id 9), linear
      Dependency 1, u (id 3), non-linear
    Equation 2, FloatEquation solving for f_11 (id 11)
      Dependency 0, f_11 (id 11), linear
      Dependency 1, J_sq (id 9), non-linear
Storage:
  Storing initial conditions: yes
  Storing equation non-linear dependencies: yes
  Initial conditions stored: 2
  Initial conditions referenced: 0
Checkpointing:
  Method: memory

We see that there are three records.

  • Equation 0, an EquationSolver. This records the solution of the finite element variational problem for u.

  • Equation 1, an Assembly. This records the calculation of the square of the \(L^2\)-norm.

  • Equation 2, a FloatEquation. This records the calculation of the square root of the square of the \(L^2\)-norm.

Computing derivatives using an adjoint

The compute_gradient function can be used to compute derivatives using the adjoint method. Here we compute the derivative of the \(L^2\)-norm of the solution, considered as a function of the control defined by m, with respect to this control:

[4]:
dJ_dm = compute_gradient(J, m)

Here each degree of freedom associated with dJ_dm contains the derivative of the functional with respect to the corresponding degree of freedom for the control. dJ_dm represents a ‘dual space’ object, defining a linear functional which, given a ‘direction’ \(\zeta \in V\), can be used to compute the directional derivative with respect to \(m\) with direction \(\zeta\).

For example we can compute the directional derivative of the functional with respect to the control \(m\) with direction equal to the unity-valued function via:

[5]:
one = Function(space, name="one")
one.interpolate(Constant(1.0))

dJ_dm_one = var_inner(one, dJ_dm)

print(f"{dJ_dm_one=}")
dJ_dm_one=0.020351453243200264

This result is the derivative of the \(L^2\)-norm of the solution with respect to the amplitude of a spatially constant perturbation to the control \(m\). We can compare with the result from finite differencing:

[6]:
def dJ_dm(J, m, *, eps=1.0e-7):
    # Second order centred finite difference approximation to dJ/dm
    return (J(m + eps) - J(m - eps)) / (2.0 * eps)


print(f"dJ_dm_one approximation = {dJ_dm(lambda eps: float(forward(m + eps * one)), 0.0)}")
dJ_dm_one approximation = 0.020351453248329543
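
The gradient can also be verified more systematically using a Taylor remainder convergence test. The following is a minimal sketch (not run as part of the original notebook) using tlm_adjoint's taylor_test function, assuming the manager still holds the records from the earlier forward calculation and that J.value returns the value of the functional:

# Verification sketch: taylor_test perturbs the control, re-runs the forward
# calculation, and checks that the Taylor remainder converges at second order
# when the computed gradient is supplied
min_order = taylor_test(forward, m, J_val=J.value, dJ=compute_gradient(J, m))
# For a correct gradient min_order should be close to two
assert min_order > 1.99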

Computing Hessian information using an adjoint of a tangent-linear

Single direction

We next seek to compute the action of the Hessian of the functional on some direction \(\zeta \in V\), using the adjoint method applied to tangent-linear and forward calculations. This can be handled directly by configuring the relevant tangent-linear and computing the derivative using compute_gradient:

[7]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta = var_tlm(J, (m, zeta))

d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)

The Hessian class applies the same approach, but handles several of the steps for us:

[8]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


H = Hessian(forward)

zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))

_, dJ_dm_zeta, d2J_dm_zeta_dm = H.action(m, zeta)

Multiple directions

If we want to compute the Hessian action on multiple directions we can define multiple tangent-linears:

[9]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta_0 = Function(space, name="zeta_0")
zeta_0.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta_0))

zeta_1 = Function(space, name="zeta_1")
zeta_1.interpolate(sin(pi * X[0]) * sin(2.0 * pi * X[1]))
configure_tlm((m, zeta_1))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta_0 = var_tlm(J, (m, zeta_0))
dJ_dm_zeta_1 = var_tlm(J, (m, zeta_1))

d2J_dm_zeta_0_dm, d2J_dm_zeta_1_dm = compute_gradient((dJ_dm_zeta_0, dJ_dm_zeta_1), m)

There are now calculations for two sets of tangent-linear variables, two sets of first order adjoint variables, and two sets of second order adjoint variables. However the two sets of first order adjoint variables have the same values – by default tlm_adjoint detects this and only computes them once.
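
As a check, we can compare the Hessian action on zeta_0 with a centred finite difference of directional first derivatives. The following is a minimal sketch (not part of the original notebook; the helper dJ_dm_dot_zeta_0 and the functions m_p and m_m are introduced purely for illustration):

# Verify the Hessian action on zeta_0 against a centred finite difference of
# the directional derivative of J with respect to m with direction zeta_0


def dJ_dm_dot_zeta_0(m_val):
    reset_manager()
    start_manager()
    J = forward(m_val)
    stop_manager()
    return var_inner(zeta_0, compute_gradient(J, m_val))


eps = 1.0e-6

m_p = Function(space, name="m_p")
m_p.interpolate(m + eps * zeta_0)
m_m = Function(space, name="m_m")
m_m.interpolate(m - eps * zeta_0)

d2J_fd = (dJ_dm_dot_zeta_0(m_p) - dJ_dm_dot_zeta_0(m_m)) / (2.0 * eps)
print(f"{d2J_fd=}")
print(f"{var_inner(zeta_0, d2J_dm_zeta_0_dm)=}")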

The above approach requires us to know the directions before the forward calculation. However some algorithms can generate the directions sequentially, and we do not know the next direction until a Hessian action on the previous direction has been computed. If possible we still want to avoid re-running the forward calculation each time we have a new direction.

If we have sufficient memory available, and in particular so long as we do not need to use checkpointing for the adjoint calculations, we can make use of the CachedHessian class. This stores the forward solution and, by default, caches and reuses first order adjoint values. Here we do not need to configure the tangent-linear before the forward calculation – instead tlm_adjoint performs the tangent-linear calculations after the forward calculations:

[10]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


start_manager()
J = forward(m)
stop_manager()

H = CachedHessian(J)

zeta_0 = Function(space, name="zeta_0")
zeta_0.interpolate(sin(pi * X[0]) * sin(pi * X[1]))

zeta_1 = Function(space, name="zeta_1")
zeta_1.interpolate(sin(pi * X[0]) * sin(2.0 * pi * X[1]))

_, dJ_dm_zeta_0, d2J_dm_zeta_0_dm = H.action(m, zeta_0)
_, dJ_dm_zeta_1, d2J_dm_zeta_1_dm = H.action(m, zeta_1)

Assembly and solver caching

Using an EquationSolver

The calculation of the Hessian action includes the solution of four discrete Poisson problems: one in the original forward calculation, one in the tangent-linear calculation, and one each in the first and second order adjoint calculations. In this self-adjoint problem the finite element matrix – a stiffness matrix – is the same across all four calculations, and hence we can cache and reuse it. Moreover we can cache and reuse linear solver data – for example we can cache and reuse the Cholesky factorization.
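
Schematically, writing \(K\) for the stiffness matrix, \(b \left( m \right)\) for the discrete right-hand-side, \(\tau_u\) for the tangent-linear solution, and \(\lambda^{(1)}\) and \(\lambda^{(2)}\) for the first and second order adjoint solutions, the four solves take the form

\[K u = b \left( m \right), \qquad K \tau_u = b' \left( m \right) \zeta, \qquad K^T \lambda^{(1)} = \cdots, \qquad K^T \lambda^{(2)} = \cdots,\]

where the adjoint right-hand-sides are omitted. Since \(K\) is symmetric, a single Cholesky factorization \(K = L L^T\) suffices for all four solves.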

tlm_adjoint can apply such caching automatically, but we must interact directly with the object tlm_adjoint uses to record the solution of finite element variational problems – the EquationSolver previously seen when we used manager_info(). This looks like:

[11]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()
clear_caches()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])


def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta = var_tlm(J, (m, zeta))

d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)

The key changes here are:

  • The use of clear_caches. This ensures that any previously cached data is cleared, avoiding memory leaks if the code is run more than once.

  • The instantiation of an EquationSolver, and the call to its solve method.

If we query the relevant tlm_adjoint caches we find:

[12]:
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")

assert len(assembly_cache()) == 1
assert len(linear_solver_cache()) == 1
len(assembly_cache())=1
len(linear_solver_cache())=1

and in particular we see that tlm_adjoint has cached linear solver data associated with a single matrix, and has cached a single assembled object (which turns out to be the matrix itself). The assembled object is a stiffness matrix, and the linear solver data stores its Cholesky factorization. The factorization is used four times: in the forward, tangent-linear, and first and second order adjoint calculations.

Flagging data for caching

Now consider the slightly different calculation:

[13]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()
clear_caches()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])

one = Constant(1.0, name="one")


def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta = var_tlm(J, (m, zeta))

d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)

print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")

assert len(assembly_cache()) == 0
assert len(linear_solver_cache()) == 0
len(assembly_cache())=0
len(linear_solver_cache())=0

The only difference is the introduction of the multiplication by one on the left-hand-side of the finite element variational problem. However we now find that no matrix or linear solver data has been cached. The issue is that tlm_adjoint does not know that it should cache the results of calculations involving one.

The ‘cache’ flag

To resolve this, we can flag one for caching using cache=True:

[14]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()
clear_caches()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])

one = Constant(1.0, name="one", cache=True)


def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta = var_tlm(J, (m, zeta))

d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)

print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")

assert len(assembly_cache()) == 2
assert len(linear_solver_cache()) == 1
len(assembly_cache())=2
len(linear_solver_cache())=1

We now see that tlm_adjoint has cached linear solver data associated with a single matrix. However two assembled objects have been cached – it turns out that there are now two cached matrices.

The extra cached matrix appears in the tangent-linear calculation, and involves the tangent-linear variable associated with one – a tangent-linear right-hand-side term has been converted into a matrix multiply using a different matrix. However in the above calculation we know that this tangent-linear variable must be zero, since the calculation for one does not depend on the control variable. The extra term in the tangent-linear calculation is therefore also zero, as the following schematic makes explicit.
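
Schematically, writing \(c\) for the value of one, and \(\tau_c\), \(\tau_m\), and \(\tau_u\) for the tangent-linear variables associated with one, m, and u respectively, the tangent-linear of the discrete equation

\[\forall \zeta' \in V_0 \qquad c \int_\Omega \nabla \zeta' \cdot \nabla u = \int_\Omega \zeta' m^2\]

is

\[\forall \zeta' \in V_0 \qquad c \int_\Omega \nabla \zeta' \cdot \nabla \tau_u = 2 \int_\Omega \zeta' m \tau_m - \tau_c \int_\Omega \nabla \zeta' \cdot \nabla u.\]

The final right-hand-side term is the one converted into a matrix multiply: the assembled matrix lacks the factor \(c\), and so differs from the cached left-hand-side matrix. Since one never changes we in fact have \(\tau_c = 0\), and the term vanishes.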

The ‘static’ flag

We can let tlm_adjoint know that one does not change, and resolve this inefficiency, by instead using static=True:

[15]:
from firedrake import *
from tlm_adjoint.firedrake import *

import numpy as np

reset_manager()
clear_caches()

mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)

space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

m = Function(space, name="m")
m.interpolate(X[0] * X[1])

one = Constant(1.0, name="one", static=True)


def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()

    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J


zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))

start_manager()
J = forward(m)
stop_manager()

dJ_dm_zeta = var_tlm(J, (m, zeta))

d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)

print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")

assert len(assembly_cache()) == 1
assert len(linear_solver_cache()) == 1
len(assembly_cache())=1
len(linear_solver_cache())=1

Here static=True leads to one being flagged as a variable whose value is never updated. From this tlm_adjoint can infer that the associated tangent-linear variable is zero, and avoid adding the zero-valued tangent-linear term.

The key difference between cache=True and static=True is that in the former case the value of the variable may be updated. So long as tlm_adjoint is aware of the update (which is the case, for example, when tlm_adjoint records a calculation), updating the value of a variable invalidates the associated cache entries, and invalidated cache entries are cleared automatically.
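
For example, returning to the cache=True variant above, updating the value of one invalidates the cached matrix and cached factorization assembled using its old value. The following is a hypothetical sketch (not part of the original notebook), assuming that the overridden Constant.assign updates the variable state in a way tlm_adjoint observes:

# Hypothetical continuation of the cache=True example: assigning a new value
# to one updates its state, invalidating the cache entries computed using the
# old value. Invalidated entries are cleared automatically.
one.assign(2.0)

print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")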