Time-independent example
This notebook describes adjoint and Hessian calculations using tlm_adjoint with the Firedrake backend. A time-independent problem is considered and, importantly, checkpointing is not used for the adjoint calculations. This notebook further describes how variables may be flagged to facilitate caching.
The high-level algorithmic differentiation approach used by tlm_adjoint is based on the method described in:
P. E. Farrell, D. A. Ham, S. W. Funke, and M. E. Rognes, ‘Automated derivation of the adjoint of high-level transient finite element programs’, SIAM Journal on Scientific Computing 35(4), pp. C369–C393, 2013, doi: 10.1137/120873558
The caching of data in tlm_adjoint uses an approach based on that described in:
J. R. Maddison and P. E. Farrell, ‘Rapid development and adjoining of transient finite element models’, Computer Methods in Applied Mechanics and Engineering, 276, 95–121, 2014, doi: 10.1016/j.cma.2014.03.010
Forward problem
We consider the solution of a linear time-independent partial differential equation, followed by the calculation of the \(L^2\)-norm of the solution. Extra non-linearity is introduced by allowing the right-hand side of the partial differential equation to depend non-linearly on the control. We assume real spaces and a real build of Firedrake throughout.
Specifically we consider the solution \(u \in V_0\) of

\[\forall \zeta \in V_0 \qquad \int_\Omega \nabla \zeta \cdot \nabla u \, dx = \int_\Omega \zeta m^2 \, dx,\]

where \(V\) is a real \(P_1\) continuous finite element space defining functions on the domain \(\Omega = \left( 0, 1 \right)^2\), with \(m \in V\), and where \(V_0\) consists of the functions in \(V\) which have zero trace. This corresponds to a discretization of the partial differential equation

\[-\nabla^2 u = m^2 \quad \text{on } \Omega,\]

subject to homogeneous Dirichlet boundary conditions.
A simple implementation in Firedrake, with \(m = x y\), takes the form:
[1]:
from firedrake import *
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
u = Function(space, name="u")
solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
      DirichletBC(space, 0.0, "on_boundary"))
J_sq = assemble(inner(u, u) * dx)
J = sqrt(J_sq)
Adding tlm_adjoint
We first modify the code so that tlm_adjoint processes the calculations:
[2]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
start_manager()
J = forward(m)
stop_manager()
[2]:
(True, True)
The key changes here are:

- Importing tlm_adjoint with the Firedrake backend.
- Controlling the ‘manager’ – an object tlm_adjoint uses to process equations.
- Using a Functional to compute the square of the \(L^2\)-norm of the solution of the (discretized) partial differential equation. This facilitates the calculation of simple functionals, e.g. using finite element assembly.
- Taking the square root of the square of the \(L^2\)-norm using NumPy.
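Note that the returned J behaves like a scalar: the square root is itself recorded as a calculation, while an ordinary float value can still be extracted. A minimal optional check (the printed value depends on the problem and mesh):

# Sketch: extract a plain float value from the recorded functional
print(f"{float(J)=}")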
Let’s display the information recorded by tlm_adjoint:
[3]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
Block 0
Equation 0, EquationSolver solving for u (id 6)
Dependency 0, u (id 6), linear
Dependency 1, m (id 1), non-linear
Equation 1, Assembly solving for J_sq (id 12)
Dependency 0, J_sq (id 12), linear
Dependency 1, u (id 6), non-linear
Equation 2, FloatEquation solving for f_14 (id 14)
Dependency 0, f_14 (id 14), linear
Dependency 1, J_sq (id 12), non-linear
Storage:
Storing initial conditions: yes
Storing equation non-linear dependencies: yes
Initial conditions stored: 1
Initial conditions referenced: 0
Checkpointing:
Method: memory
We see that there are three records.
Equation 0, an
EquationSolver
. This records the solution of the finite element variational problem foru
.Equation 1, an
Assembly
. This records the calculation of the square of the \(L^2\)-norm.Equation 2, a
FloatEquation
. This records the calculation of the square root of the square of the \(L^2\)-norm.
Computing derivatives using an adjoint
The compute_gradient function can be used to compute derivatives using the adjoint method. Here we compute the derivative of the \(L^2\)-norm of the resulting solution, considered as a function of the control defined by m, with respect to this control:
[4]:
dJ_dm = compute_gradient(J, m)
Here each degree of freedom associated with dJ_dm contains the derivative of the functional with respect to the corresponding degree of freedom for the control. dJ_dm represents a ‘dual space’ object, defining a linear functional which, given a ‘direction’ \(\zeta \in V\), can be used to compute the directional derivative with respect to \(m\) with direction \(\zeta\).
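As an aside, if a primal space representative of dJ_dm is wanted, for example for visualization, one option is to apply a Riesz map. A minimal sketch, assuming a Firedrake version that provides the riesz_representation method:

# Sketch: map the dual space derivative to a Function via an L^2 Riesz map
dJ_dm_primal = dJ_dm.riesz_representation("L2")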
For example we can compute the directional derivative of the functional with respect to the control \(m\) with direction equal to the unity valued function via:
[5]:
one = Function(space, name="one")
one.interpolate(Constant(1.0))
dJ_dm_one = assemble(action(dJ_dm, one))
print(f"{dJ_dm_one=}")
dJ_dm_one=0.020351453243200264
This result is the derivative of the \(L^2\)-norm of the solution with respect to the amplitude of a spatially constant perturbation to the control \(m\). We can compare with the result from finite differencing:
[6]:
def dJ_dm(J, m, *, eps=1.0e-7):
    # Second order central difference approximation
    return (J(m + eps) - J(m - eps)) / (2.0 * eps)
print(f"dJ_dm_one approximation = {dJ_dm(lambda eps: float(forward(m + eps * one)), 0.0)}")
dJ_dm_one approximation = 0.020351453248329543
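As a further optional check, the central difference should converge at roughly second order in eps, until round-off error dominates. A minimal sketch reusing the functions above (the printed errors depend on the problem and mesh):

# Sketch: finite difference error for decreasing eps
for eps in (1.0e-2, 1.0e-3, 1.0e-4):
    fd = dJ_dm(lambda e: float(forward(m + e * one)), 0.0, eps=eps)
    print(f"eps={eps:.1e} error={abs(fd - dJ_dm_one):.3e}")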
Computing Hessian information using an adjoint of a tangent-linear
Single direction
We next seek to compute the action of the Hessian of the functional on some direction \(\zeta \in V\), using the adjoint method applied to tangent-linear and forward calculations. This can be handled directly, by configuring the relevant tangent-linear and computing the derivative using compute_gradient:
[7]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta = var_tlm(J, (m, zeta))
d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)
The Hessian class applies the same approach, but handles several of the steps for us:
[8]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
H = Hessian(forward)
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
_, dJ_dm_zeta, d2J_dm_zeta_dm = H.action(m, zeta)
Multiple directions
If we want to compute the Hessian action on multiple directions we can define multiple tangent-linears:
[9]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta_0 = Function(space, name="zeta_0")
zeta_0.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta_0))
zeta_1 = Function(space, name="zeta_1")
zeta_1.interpolate(sin(pi * X[0]) * sin(2.0 * pi * X[1]))
configure_tlm((m, zeta_1))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta_0 = var_tlm(J, (m, zeta_0))
dJ_dm_zeta_1 = var_tlm(J, (m, zeta_1))
d2J_dm_zeta_0_dm, d2J_dm_zeta_1_dm = compute_gradient((dJ_dm_zeta_0, dJ_dm_zeta_1), m)
There are now calculations for two sets of tangent-linear variables, two sets of first order adjoint variables, and two sets of second order adjoint variables. However the two sets of first order adjoint variables have the same values – by default tlm_adjoint detects this and only computes them once.
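Since the Hessian is symmetric, the two results admit a cross-check: the Hessian action in the direction zeta_0, evaluated at zeta_1, should equal (up to round-off) the Hessian action in the direction zeta_1, evaluated at zeta_0. A minimal sketch, using the same assemble(action(...)) pattern as before:

# Sketch: Hessian symmetry check, < H zeta_0, zeta_1 > == < H zeta_1, zeta_0 >
a = assemble(action(d2J_dm_zeta_0_dm, zeta_1))
b = assemble(action(d2J_dm_zeta_1_dm, zeta_0))
print(f"{abs(a - b)=}")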
The above approach requires us to know the directions before the forward calculation. However some algorithms can generate the directions sequentially, and we do not know the next direction until a Hessian action on the previous direction has been computed. If possible we still want to avoid re-running the forward calculation each time we have a new direction.
If we have sufficient memory available, and in particular so long as we do not need to use checkpointing for the adjoint calculations, we can make use of the CachedHessian class. This stores the forward solution and, by default, caches and reuses first order adjoint values. Here we do not need to configure the tangent-linear before the forward calculation – instead tlm_adjoint performs the tangent-linear calculations after the forward calculations:
[10]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    solve(inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
          DirichletBC(space, 0.0, "on_boundary"))
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
start_manager()
J = forward(m)
stop_manager()
H = CachedHessian(J)
zeta_0 = Function(space, name="zeta_0")
zeta_0.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
zeta_1 = Function(space, name="zeta_1")
zeta_1.interpolate(sin(pi * X[0]) * sin(2.0 * pi * X[1]))
_, dJ_dm_zeta_0, d2J_dm_zeta_0_dm = H.action(m, zeta_0)
_, dJ_dm_zeta_1, d2J_dm_zeta_1_dm = H.action(m, zeta_1)
Assembly and solver caching
Using an EquationSolver
The calculation for the Hessian action includes four discrete Poisson equations: one for the original forward calculation, one for the tangent-linear calculation, and one each for the first and second order adjoint calculations. In this self-adjoint problem the finite element matrix – a stiffness matrix – is the same across all four calculations. Hence we can cache and reuse it, and moreover we can cache and reuse linear solver data – for example a Cholesky factorization.
tlm_adjoint can apply such caching automatically, but we must interact directly with the object tlm_adjoint uses to record the solution of finite element variational problems – the EquationSolver previously seen when we used manager_info(). This looks like:
[11]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
clear_caches()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta = var_tlm(J, (m, zeta))
d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)
The key changes here are:

- The use of clear_caches. This ensures that any previously cached data is cleared, avoiding memory leaks if the code is run more than once.
- The instantiation of an EquationSolver, and the call to its solve method.
If we query the relevant tlm_adjoint caches we find:
[12]:
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")
assert len(assembly_cache()) == 1
assert len(linear_solver_cache()) == 1
len(assembly_cache())=1
len(linear_solver_cache())=1
and in particular we see that tlm_adjoint has cached a single assembled object – which turns out to be the stiffness matrix itself – and has cached linear solver data associated with a single matrix – its Cholesky factorization. The factorization is used four times: in the forward, tangent-linear, and first and second order adjoint calculations.
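If the cached data is no longer needed it can be freed explicitly, as already done at the start of the cell above. A minimal sketch:

# Sketch: explicitly clear the tlm_adjoint caches
clear_caches()
assert len(assembly_cache()) == 0
assert len(linear_solver_cache()) == 0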
Flagging data for caching
Now consider the slightly different calculation:
[13]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
clear_caches()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
one = Constant(1.0, name="one")
def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta = var_tlm(J, (m, zeta))
d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")
assert len(assembly_cache()) == 0
assert len(linear_solver_cache()) == 0
len(assembly_cache())=0
len(linear_solver_cache())=0
The only difference is the introduction of the multiplication by one on the left-hand side of the finite element variational problem. However we now find that no matrix or linear solver data has been cached. The issue is that tlm_adjoint does not know that it should cache the results of calculations involving one.
The ‘cache’ flag
To resolve this, we can flag one for caching using cache=True:
[14]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
clear_caches()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
one = Constant(1.0, name="one", cache=True)
def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta = var_tlm(J, (m, zeta))
d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")
assert len(assembly_cache()) == 2
assert len(linear_solver_cache()) == 1
len(assembly_cache())=2
len(linear_solver_cache())=1
We now see that tlm_adjoint has cached linear solver data associated with a single matrix. However the assembly of two objects has been cached – it turns out that there are now two cached matrices.
The extra cached matrix appears in the tangent-linear calculations, involving the tangent-linear variable associated with one – a tangent-linear right-hand-side term has been converted into a matrix multiplication using a different matrix. However in the above calculation we know that this tangent-linear variable must be zero, since the calculation for one does not depend on the control variable. The extra term in the tangent-linear calculation is similarly also known to be zero.
The ‘static’ flag
We can let tlm_adjoint know that one does not change, and resolve this inefficiency, by instead using static=True:
[15]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
reset_manager()
clear_caches()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
m = Function(space, name="m")
m.interpolate(X[0] * X[1])
one = Constant(1.0, name="one", static=True)
def forward(m):
    u = Function(space, name="u")
    eq = EquationSolver(
        one * inner(grad(trial), grad(test)) * dx == inner(m * m, test) * dx, u,
        DirichletBC(space, 0.0, "on_boundary"),
        solver_parameters={"ksp_type": "preonly",
                           "pc_type": "cholesky"})
    eq.solve()
    J_sq = Functional(name="J_sq")
    J_sq.assign(inner(u, u) * dx)
    J = np.sqrt(J_sq)
    return J
zeta = Function(space, name="zeta")
zeta.interpolate(sin(pi * X[0]) * sin(pi * X[1]))
configure_tlm((m, zeta))
start_manager()
J = forward(m)
stop_manager()
dJ_dm_zeta = var_tlm(J, (m, zeta))
d2J_dm_zeta_dm = compute_gradient(dJ_dm_zeta, m)
print(f"{len(assembly_cache())=}")
print(f"{len(linear_solver_cache())=}")
assert len(assembly_cache()) == 1
assert len(linear_solver_cache()) == 1
len(assembly_cache())=1
len(linear_solver_cache())=1
Here static=True leads to one being flagged as a variable whose value is never updated. From this tlm_adjoint can infer that the associated tangent-linear variable is zero, and avoid adding the zero-valued tangent-linear term.

The key difference between cache=True and static=True is that in the former case the value of the variable may be updated. So long as tlm_adjoint is aware of the update (which happens, for example, when tlm_adjoint records a calculation), updating the value of a variable invalidates cache entries, and invalidated cache entries are cleared automatically.