Defining custom operations
This notebook describes the tlm_adjoint ‘escape hatch’ mechanism, which allows custom operations to be defined, recorded by the internal manager, and used in higher order derivative calculations.
The approach used by tlm_adjoint to define custom operations is described in
James R. Maddison, Daniel N. Goldberg, and Benjamin D. Goddard, ‘Automated calculation of higher order partial differential equation constrained derivative information’, SIAM Journal on Scientific Computing, 41(5), pp. C417–C445, 2019, doi: 10.1137/18M1209465
We assume real spaces and a real build of Firedrake throughout.
Forward problem
We consider the Poisson equation in the unit square domain

\[\nabla^2 u = m \qquad \text{on} ~ \Omega = \left( 0, 1 \right)^2,\]

with

\[m = \cos \left( \pi x \right) \cos \left( \pi y \right),\]

subject to boundary conditions

\[\nabla u \cdot \hat{n} = 0 \qquad \text{on} ~ \partial \Omega,\]

where \(\hat{n}\) is an outward unit normal on the boundary \(\partial \Omega\) of the domain \(\Omega\).
We consider a continuous Galerkin discretization, now seeking \(u \in V\) such that

\[\forall \zeta \in V \qquad \int_\Omega \nabla u \cdot \nabla \zeta = -\int_\Omega m \zeta,\]

where \(V\) is a real \(P_1\) continuous finite element space defining functions on the domain \(\Omega = \left( 0, 1 \right)^2\). \(m \in V\) is now defined via

\[m = \mathcal{I} \left[ \cos \left( \pi x \right) \cos \left( \pi y \right) \right] - k,\]

where \(\mathcal{I}\) maps to an element of \(V\) through interpolation at mesh vertices, and where \(k\) is defined so that \(\int_\Omega m = 0\).
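Concretely (this restates what the code below assembles, rather than adding anything new), \(k\) is the mean value of the interpolant,

\[k = \frac{\int_\Omega \mathcal{I} \left[ \cos \left( \pi x \right) \cos \left( \pi y \right) \right]}{\int_\Omega 1}.\]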
As stated the discrete problem does not have a unique solution – given any solution \(u_0\) to the problem we can define a new solution \(u_1 = u_0 + c_1\) for any scalar \(c_1\). We need to identify a solution by supplying an appropriate constraint. We will use the constraint

\[\sum_i \tilde{u}_i = 0,\]

where the \(\tilde{u}_i\) are the degrees of freedom associated with \(u \in V\). In the implementation this constraint is enforced via the linear solver nullspace: with a conjugate gradient solver the component of the solution in the constant nullspace is removed, so that the vector of degrees of freedom is orthogonal to the constant vector.
An implementation using Firedrake, solving for \(u\) and then computing \(\frac{1}{2} \int_\Omega u^2\), takes the form
[1]:
%matplotlib inline
from firedrake import *
from firedrake.pyplot import tricontourf
import matplotlib.pyplot as plt
import numpy as np
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test, trial = TestFunction(space), TrialFunction(space)
def forward(m):
    # Shift m so that it has zero mean
    k = Constant(assemble(m * dx) / assemble(Constant(1.0) * dx(mesh)))
    m_tilde = Function(m.function_space(), name="m_tilde").interpolate(m - k)

    # Solve the discrete Poisson problem, removing the constant nullspace
    # component from the solution
    u = Function(space, name="u")
    nullspace = VectorSpaceBasis(comm=u.comm, constant=True)
    solve(inner(grad(trial), grad(test)) * dx == -inner(m_tilde, test) * dx,
          u, nullspace=nullspace, transpose_nullspace=nullspace,
          solver_parameters={"ksp_type": "cg", "pc_type": "hypre",
                             "pc_hypre_type": "boomeramg",
                             "ksp_rtol": 1.0e-10, "ksp_atol": 1.0e-16})

    J = assemble(0.5 * inner(u, u) * dx)
    return u, J
m = Function(space, name="m").interpolate(cos(pi * X[0]) * cos(pi * X[1]))
u, J = forward(m)
def plot_output(u, title):
    r = (u.dat.data_ro.min(), u.dat.data_ro.max())
    eps = (r[1] - r[0]) * 1.0e-12
    p = tricontourf(u, np.linspace(r[0] - eps, r[1] + eps, 32))
    plt.gca().set_title(title)
    plt.colorbar(p)
    plt.gca().set_aspect(1.0)
plot_output(u, title="u")
Adding tlm_adjoint
Next we add the use of tlm_adjoint, which requires some minor modification of the forward code. compute_gradient is used to compute the derivative of \(\frac{1}{2} \int_\Omega u^2\) with respect to \(m\), subject to the constraint that the forward problem is solved. The derivative is visualized using a Riesz map defined by the \(L^2\) inner product. A Taylor remainder convergence test is used to verify the derivative.
Note that, given an arbitrary \(m \in V\), we define

\[\tilde{m} = m - \frac{\int_\Omega m}{\int_\Omega 1},\]

so that \(\int_\Omega \tilde{m} = 0\), and use \(\tilde{m}\) in place of \(m\) when solving the discrete Poisson problem.
[2]:
%matplotlib inline
from firedrake import *
from tlm_adjoint.firedrake import *
from firedrake.pyplot import tricontourf
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(54151610)
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test, trial = TestFunction(space), TrialFunction(space)
def forward(m):
    # Assemble the domain volume and the mean of m as recorded operations
    v = Constant(name="v")
    Assembly(v, Constant(1.0) * dx(mesh)).solve()
    k = Constant(name="k")
    Assembly(k, (m / v) * dx).solve()
    m_tilde = Function(m.function_space(), name="m_tilde").interpolate(m - k)

    u = Function(space, name="u")
    nullspace = VectorSpaceBasis(comm=u.comm, constant=True)
    solve(inner(grad(trial), grad(test)) * dx == -inner(m_tilde, test) * dx,
          u, nullspace=nullspace, transpose_nullspace=nullspace,
          solver_parameters={"ksp_type": "cg", "pc_type": "hypre",
                             "pc_hypre_type": "boomeramg",
                             "ksp_rtol": 1.0e-10, "ksp_atol": 1.0e-16})

    J = Functional(name="J")
    J.assign(0.5 * inner(u, u) * dx)
    return u, J
m = Function(space, name="m").interpolate(cos(pi * X[0]) * cos(pi * X[1]))
start_manager()
u, J = forward(m)
stop_manager()
dJ = compute_gradient(J, m)
def plot_output(u, title):
    r = (u.dat.data_ro.min(), u.dat.data_ro.max())
    eps = (r[1] - r[0]) * 1.0e-12
    p = tricontourf(u, np.linspace(r[0] - eps, r[1] + eps, 32))
    plt.gca().set_title(title)
    plt.colorbar(p)
    plt.gca().set_aspect(1.0)
plot_output(u, title="u")
plot_output(dJ.riesz_representation("L2"), title=r"$g^\sharp$")
def forward_J(m):
    _, J = forward(m)
    return J
min_order = taylor_test(forward_J, m, J_val=J.value, dJ=dJ)
assert min_order > 1.99
Error norms, no adjoint = [6.16667546e-08 3.06309860e-08 1.52648952e-08 7.61979814e-09
3.80673670e-09]
Orders, no adjoint = [1.00950111 1.00477412 1.002393 1.00119799]
Error norms, with adjoint = [8.09565105e-10 2.02391272e-10 5.05978163e-11 1.26494529e-11
3.16236247e-12]
Orders, with adjoint = [2.00000003 2.00000005 2.00000013 2.00000034]
Defining a custom Equation
Since the \(u\) we have computed should satisfy

\[\sum_i \tilde{u}_i = 0,\]

we should have that the derivative of this quantity with respect to \(m\), subject to the forward problem being solved, is zero. We will compute this derivative by defining a custom operation which sums the degrees of freedom of a finite element discretized function. We can do this by inheriting from the Equation class and implementing appropriate methods.
[3]:
from tlm_adjoint import Equation
class Sum(Equation):
    def __init__(self, x, y):
        super().__init__(x, deps=(x, y), nl_deps=(), ic=False, adj_ic=False,
                         adj_type="conjugate_dual")

    def forward_solve(self, x, deps=None):
        if deps is None:
            deps = self.dependencies()
        _, y = deps
        with y.dat.vec_ro as y_v:
            y_sum = y_v.sum()
        x.assign(y_sum)

    def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
        return b

    def adjoint_derivative_action(self, nl_deps, dep_index, adj_x):
        if dep_index != 1:
            raise ValueError("Unexpected dep_index")
        _, y = self.dependencies()
        b = Cofunction(var_space(y).dual())
        adj_x = float(adj_x)
        b.dat.data[:] = -adj_x
        return b

    def tangent_linear(self, tlm_map):
        tau_x, tau_y = (tlm_map[dep] for dep in self.dependencies())
        if tau_y is None:
            return ZeroAssignment(tau_x)
        else:
            return Sum(tau_x, tau_y)
We’ll investigate how this is put together shortly.
First, we check that the Sum class works. In the following we check that we can compute the sum of the degrees of freedom of a function, and then use Taylor remainder convergence tests to verify a first order tangent-linear calculation, a first order adjoint calculation, and a reverse-over-forward adjoint calculation of a Hessian action.
[4]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
np.random.seed(13561700)
reset_manager()
def forward(m):
    m_sum = Float(name="m_sum")
    Sum(m_sum, m).solve()
    J = m_sum ** 3
    return m_sum, J


def forward_J(m):
    _, J = forward(m)
    return J
mesh = UnitIntervalMesh(10)
x, = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
m = Function(space, name="m")
m.interpolate(exp(-x) * sin(pi * x))
with m.dat.vec_ro as m_v:
    m_sum_ref = m_v.sum()
start_manager()
m_sum, J = forward(m)
stop_manager()
# Verify the forward calculation
print(f"{float(m_sum)=}")
print(f"{m_sum_ref=}")
assert abs(float(m_sum) - m_sum_ref) == 0.0
# Verify tangent-linear and adjoint calculations
min_order = taylor_test_tlm(forward_J, m, tlm_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(forward_J, m, adjoint_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(forward_J, m, adjoint_order=2)
assert min_order > 1.99
float(m_sum)=3.97146153200585
m_sum_ref=3.97146153200585
Error norms, no tangent-linear = [1.4597981 0.72709661 0.36284904 0.18124987 0.09058129]
Orders, no tangent-linear = [1.00554988 1.00277762 1.00138948 1.00069491]
Error norms, with tangent-linear = [1.11954288e-02 2.79527053e-03 6.98369299e-04 1.74536283e-04
4.36270656e-05]
Orders, with tangent-linear = [2.00184997 2.00092587 2.00046316 2.00023164]
Error norms, no adjoint = [1.68516173 0.83885294 0.41849655 0.20901605 0.10445 ]
Orders, no adjoint = [1.00639725 1.00320219 1.00160198 1.00080121]
Error norms, with adjoint = [1.48897315e-02 3.71693487e-03 9.28546465e-04 2.32050710e-04
5.80019391e-05]
Orders, with adjoint = [2.00213243 2.0010674 2.00053399 2.00026707]
Error norms, no adjoint = [3.70748562 1.85093251 0.92476368 0.46220619 0.23105919]
Orders, no adjoint = [1.00218881 1.00109565 1.00054814 1.00027415]
Error norms, with adjoint = [1.12412164e-02 2.81030409e-03 7.02576023e-04 1.75644006e-04
4.39110014e-05]
Orders, with adjoint = [2. 2. 2. 2.]
We can now use Sum with the Poisson solver. We find, as expected, that the derivative of the sum of the degrees of freedom for \(u\) with respect to \(m\), subject to the forward problem being solved, is zero: the nullspace handling enforces \(\sum_i \tilde{u}_i = 0\) for every value of \(m\), and so this sum is constant as \(m\) varies.
[5]:
%matplotlib inline
from firedrake import *
from tlm_adjoint.firedrake import *
import matplotlib.pyplot as plt
import numpy as np
reset_manager()
mesh = UnitSquareMesh(10, 10)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test, trial = TestFunction(space), TrialFunction(space)
def forward(m):
    v = Constant(name="v")
    Assembly(v, Constant(1.0) * dx(mesh)).solve()
    k = Constant(name="k")
    Assembly(k, (m / v) * dx).solve()
    m_tilde = Function(m.function_space(), name="m_tilde").interpolate(m - k)

    u = Function(space, name="u")
    nullspace = VectorSpaceBasis(comm=u.comm, constant=True)
    solve(inner(grad(trial), grad(test)) * dx == -inner(m_tilde, test) * dx,
          u, nullspace=nullspace, transpose_nullspace=nullspace,
          solver_parameters={"ksp_type": "cg", "pc_type": "hypre",
                             "pc_hypre_type": "boomeramg",
                             "ksp_rtol": 1.0e-10, "ksp_atol": 1.0e-16})

    J = Functional(name="J")
    J.assign(0.5 * inner(u, u) * dx)

    # Sum the degrees of freedom of u using the custom operation
    u_sum = Constant(name="u_sum")
    Sum(u_sum, u).solve()
    K = Functional(name="K")
    K.assign(u_sum)
    return u, J, K
m = Function(space, name="m").interpolate(cos(pi * X[0]) * cos(pi * X[1]))
start_manager()
u, J, K = forward(m)
stop_manager()
dJ, dK = compute_gradient((J, K), m)
dK_norm = abs(dK.dat.data_ro).max()
print(f"{dK_norm=}")
assert dK_norm == 0.0
dK_norm=np.float64(0.0)
Equation methods
We now return to the definition of the Sum class, and consider each method in turn.
__init__
def __init__(self, x, y):
    super().__init__(x, deps=(x, y), nl_deps=(), ic=False, adj_ic=False,
                     adj_type="conjugate_dual")
The arguments passed to the base class constructor are:

- The output of the forward operation, x.
- deps: Defines all inputs and outputs of the forward operation.
- nl_deps: Defines the elements of deps on which the associated adjoint calculations can depend.
- ic: Whether the forward operation accepts a non-zero ‘initial guess’.
- adj_ic: Whether an adjoint solve accepts a non-zero ‘initial guess’.
- adj_type: Either "primal" or "conjugate_dual", defining the space for an associated adjoint variable.
For this operation there is one output x and one input y. The operation is defined by the forward residual function

\[F \left( x, y \right) = x - \sum_i \tilde{y}_i,\]

where the \(\tilde{y}_i\) are the degrees of freedom of \(y\). Given a value for \(y\) the output of the operation is defined to be the \(x\) for which \(F \left( x, y \right) = 0\). \(F\) depends linearly on both \(x\) and \(y\), and so associated adjoint calculations are independent of both \(x\) and \(y\). Hence we set deps=(x, y) and nl_deps=().
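For contrast, consider a hypothetical nonlinear variant (not used in this notebook) which instead computed the sum of the cubes of the degrees of freedom, with residual function

\[F \left( x, y \right) = x - \sum_i \tilde{y}_i^3.\]

Here \(\partial F / \partial y\) depends on \(y\), the adjoint calculations would therefore need the forward value of \(y\), and we would set nl_deps=(y,).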
ic and adj_ic are used to indicate whether non-zero initial guesses should be supplied to forward and adjoint iterative solvers respectively. Here we do not use iterative solvers, and so no non-zero initial guesses are needed. Hence we set ic=False and adj_ic=False.
adj_type indicates whether an adjoint variable associated with the operation is:

- In the same space as x. This is the case where \(F\) maps to an element in the antidual space associated with x, and is indicated with adj_type="primal". A typical example of this case is an operation defined via the solution of a finite element variational problem.
- In the antidual space associated with the space for x. This is the case where \(F\) maps to an element in the same space as x, and is indicated with adj_type="conjugate_dual".
forward_solve
def forward_solve(self, x, deps=None):
    if deps is None:
        deps = self.dependencies()
    _, y = deps
    with y.dat.vec_ro as y_v:
        y_sum = y_v.sum()
    x.assign(y_sum)
This method computes the output of the forward operation, storing the result in x. deps is used when re-running the forward calculation, e.g. when making use of a checkpointing schedule. If supplied, deps contains values associated with the dependencies as defined in the constructor (via the deps argument passed to the base class constructor). Otherwise the values contained in self.dependencies() are used.
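As a sketch of this pattern, here the Sum class defined above is evaluated directly, once using the recorded dependencies and once with an explicitly supplied value standing in for a checkpointed value. The names y and y_0 are illustrative only, and in normal use forward_solve is called by the manager rather than by hand:

mesh = UnitIntervalMesh(5)
space = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
y = Function(space, name="y").assign(1.0)
y_0 = Function(space, name="y_0").assign(2.0)

x = Float(name="x")
eq = Sum(x, y)
eq.forward_solve(x)                 # uses self.dependencies(), sums the degrees of freedom of y
eq.forward_solve(x, deps=(x, y_0))  # uses the supplied values, sums the degrees of freedom of y_0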
adjoint_jacobian_solve
def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
    return b
Solves a linear problem

\[\left( \frac{\partial F}{\partial x} \right)^* \lambda = b\]

for the adjoint solution \(\lambda\), returning the result. The right-hand-side \(b\) is defined by the argument b. In this example \(\partial F / \partial x\) is simply the identity, and so \(\lambda = b\).
adj_x can contain an initial guess for an iterative solver used to compute \(\lambda\). nl_deps contains values associated with the forward dependencies of the adjoint, as defined in the constructor (via the nl_deps argument passed to the base class constructor).
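For the Sum operation we can read off the adjoint Jacobian directly (a restatement of the reasoning above): differentiating \(F \left( x, y \right) = x - \sum_i \tilde{y}_i\) with respect to \(x\) gives

\[\frac{\partial F}{\partial x} = 1,\]

so that the adjoint system reduces to \(\lambda = b\), which is exactly what the method returns.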
adjoint_derivative_action
def adjoint_derivative_action(self, nl_deps, dep_index, adj_x):
    if dep_index != 1:
        raise ValueError("Unexpected dep_index")
    _, y = self.dependencies()
    b = Cofunction(var_space(y).dual())
    adj_x = float(adj_x)
    b.dat.data[:] = -adj_x
    return b
Computes

\[\left( \frac{\partial F}{\partial v} \right)^* \lambda,\]

where \(v\) is the dep_index th dependency (for the dependencies defined by the deps argument passed to the base class constructor) and \(\lambda\) is the adjoint variable supplied via adj_x. Again nl_deps contains values associated with the forward dependencies of the adjoint, as defined in the constructor.
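For the Sum operation, with dep_index equal to one we have \(v = y\), and

\[\left( \frac{\partial F}{\partial y} \right)^* \lambda = -\lambda \mathbf{1},\]

a cofunction each of whose degrees of freedom equals \(-\lambda\). This matches the line b.dat.data[:] = -adj_x in the implementation above.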
tangent_linear
def tangent_linear(self, tlm_map):
    tau_x, tau_y = (tlm_map[dep] for dep in self.dependencies())
    if tau_y is None:
        return ZeroAssignment(tau_x)
    else:
        return Sum(tau_x, tau_y)
Constructs a tangent-linear operation, for a tangent-linear computing directional derivatives with respect to the control defined by M with direction defined by dM. tlm_map is used to access tangent-linear variables.
The new operation is defined by the residual function

\[F_\tau \left( \tau_x, \tau_y \right) = \tau_x - \sum_i \tilde{\tau}_{y,i},\]

where \(\tau_x\) and \(\tau_y\) are the tangent-linear variables associated with \(x\) and \(y\) respectively, and the \(\tilde{\tau}_{y,i}\) are the degrees of freedom of \(\tau_y\). Given a value for \(\tau_y\) the output of the operation is defined to be the \(\tau_x\) for which \(F_\tau \left( \tau_x, \tau_y \right) = 0\).
Note that this method returns an Equation instance – here either a ZeroAssignment, for the case where the tangent-linear operation sets a tangent-linear variable equal to zero, or another Sum. This new operation can then be recorded by the internal manager, so that reverse-over-forward algorithmic differentiation can be applied.
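For reference (a standard tangent-linear construction, restated here rather than taken from the original text), the residual \(F_\tau\) follows by differentiating the forward residual in the direction of the tangent-linear variables:

\[F_\tau \left( \tau_x, \tau_y \right) = \frac{\partial F}{\partial x} \tau_x + \frac{\partial F}{\partial y} \tau_y = \tau_x - \sum_i \tilde{\tau}_{y,i}.\]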
Reference dropping
An important method, not defined in this example, is the drop_references method. This method is used to drop references to forward data. For example, if an Equation subclass defines a Function attribute _function, then the drop_references method might look like
def drop_references(self):
    super().drop_references()
    self._function = var_replacement(self._function)
Here var_replacement returns a ‘symbolic only’ version of self._function, which can for example be used in UFL expressions, but which has no associated value.
Defining custom operations using JAX
In some cases it may be more convenient to implement an operation directly using lower level code. tlm_adjoint integrates with JAX to allow this. For example a Firedrake Function can be converted to a JAX array using

v = to_jax(x)

and the result converted back to a Firedrake Function using

x = to_firedrake(v, space)
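As a minimal round-trip sketch (assuming a serial run; the function u and the names v and w are illustrative only):

mesh = UnitIntervalMesh(10)
space = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
x_coord, = SpatialCoordinate(mesh)

u = Function(space, name="u").interpolate(x_coord)
v = to_jax(u)               # JAX representation of the degrees of freedom of u
w = to_firedrake(v, space)  # back to a Firedrake Function in the same space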
Here we use JAX to compute the sum of the degrees of freedom for a Firedrake function (assuming a serial calculation).
[6]:
from firedrake import *
from tlm_adjoint.firedrake import *
import numpy as np
np.random.seed(81976463)
reset_manager()
def forward(m):
    m_sum = new_jax_float(name="m_sum")
    call_jax(m_sum, to_jax(m), lambda v: v.sum())
    J = m_sum ** 3
    return m_sum, J


def forward_J(m):
    _, J = forward(m)
    return J
mesh = UnitIntervalMesh(10)
x, = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
m = Function(space, name="m")
m.interpolate(exp(-x) * sin(pi * x))
with m.dat.vec_ro as m_v:
    m_sum_ref = m_v.sum()
start_manager()
m_sum, J = forward(m)
stop_manager()
# Verify the forward calculation
print(f"{float(m_sum)=}")
print(f"{m_sum_ref=}")
assert abs(float(m_sum) - m_sum_ref) < 1.0e-15
# Verify tangent-linear and adjoint calculations
min_order = taylor_test_tlm(forward_J, m, tlm_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(forward_J, m, adjoint_order=1)
assert min_order > 1.99
min_order = taylor_test_tlm_adjoint(forward_J, m, adjoint_order=2)
assert min_order > 1.99
float(m_sum)=3.9714615320058506
m_sum_ref=3.97146153200585
Error norms, no tangent-linear = [1.90883165 0.94964096 0.47362975 0.23651757 0.1181845 ]
Orders, no tangent-linear = [1.0072358 1.00362246 1.00181237 1.00090647]
Error norms, with tangent-linear = [1.90676066e-02 4.75893883e-03 1.18873935e-03 2.97060420e-04
7.42495525e-05]
Orders, with tangent-linear = [2.00241195 2.00120749 2.00060412 2.00030216]
Error norms, no adjoint = [1.34856474 0.67188864 0.33534695 0.16752427 0.08372485]
Orders, no adjoint = [1.00513073 1.00256765 1.0012844 1.00064234]
Error norms, with adjoint = [9.56356602e-03 2.38805889e-03 5.96660647e-04 1.49120902e-04
3.72746931e-05]
Orders, with adjoint = [2.00171025 2.00085588 2.00042813 2.00021411]
Error norms, no adjoint = [4.47108561 2.23061531 1.11407578 0.55672992 0.27828797]
Orders, no adjoint = [1.00318344 1.00159436 1.00079784 1.00039909]
Error norms, with adjoint = [1.97099742e-02 4.92749355e-03 1.23187339e-03 3.07968347e-04
7.69920868e-05]
Orders, with adjoint = [2. 2. 2. 2.]