Getting started with tlm_adjoint
This notebook introduces derivative calculations using tlm_adjoint.
tlm_adjoint is primarily intended for first derivative calculations using the adjoint method, and Hessian information calculations using the adjoint method applied to tangent-linear and forward calculations. However, the approach used by tlm_adjoint generalizes to higher order derivatives.
The approach used by tlm_adjoint for higher order differentiation is described in:
James R. Maddison, Daniel N. Goldberg, and Benjamin D. Goddard, ‘Automated calculation of higher order partial differential equation constrained derivative information’, SIAM Journal on Scientific Computing, 41(5), pp. C417–C445, 2019, doi: 10.1137/18M1209465
A floating point example
We consider a simple floating point calculation:
[1]:
import numpy as np
x = 1.0
y = 0.25 * np.pi
z = x * np.sin(y * np.exp(y))
tlm_adjoint is designed for high-level algorithmic differentiation, not for this type of low-level floating point calculation. However, it can still process simple floating point calculations, so we use one here to introduce the key ideas. We consider differentiating z with respect to x and y; more precisely, we mean computing the derivative of the function used to compute z with respect to the variables defined by x and y.
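Because this calculation is small, the two partial derivatives can also be worked out by hand, which gives a reference for the results below. Writing \(u = y e^y\), so that \(z = x \sin u\) and \(du/dy = (1 + y) e^y\), the chain rule gives \(\partial z / \partial x = \sin u\) and \(\partial z / \partial y = x \cos \left( u \right) (1 + y) e^y\). A minimal sketch in plain Python (the function name dz_exact is illustrative and not part of tlm_adjoint):

```python
import math


def dz_exact(x, y):
    """Hand-derived partial derivatives of z = x * sin(y * exp(y))."""
    u = y * math.exp(y)              # inner argument of sin
    du_dy = (1.0 + y) * math.exp(y)  # derivative of y * exp(y) with respect to y
    dz_dx = math.sin(u)
    dz_dy = x * math.cos(u) * du_dy
    return dz_dx, dz_dy


print(dz_exact(1.0, 0.25 * math.pi))
```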
Adding tlm_adjoint
We first modify the code so that tlm_adjoint processes the calculations:
[2]:
from tlm_adjoint import *
import numpy as np
reset_manager()
def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z
x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")
start_manager()
z = forward(x, y)
stop_manager()
[2]:
(True, True)
The key changes here are:
- Importing tlm_adjoint.
- Controlling the ‘manager’, an object tlm_adjoint uses to process equations. The manager is first reset using reset_manager. This clears any previous processing and disables the manager. start_manager and stop_manager are then used to enable the manager only when it is needed.
- Defining x and y to be of type Float. Calculations involving x and y are recorded by the manager. The result of the calculations, here z, has the same type, and its value can be accessed with float(z).
Let’s display the information recorded by tlm_adjoint:
[3]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
Block 0
Equation 0, FloatEquation solving for f_2 (id 2)
Dependency 0, f_2 (id 2), linear
Dependency 1, y (id 1), non-linear
Equation 1, FloatEquation solving for f_4 (id 4)
Dependency 0, f_4 (id 4), linear
Dependency 1, y (id 1), non-linear
Dependency 2, f_2 (id 2), non-linear
Equation 2, FloatEquation solving for f_6 (id 6)
Dependency 0, f_6 (id 6), linear
Dependency 1, f_4 (id 4), non-linear
Equation 3, FloatEquation solving for f_8 (id 8)
Dependency 0, f_8 (id 8), linear
Dependency 1, x (id 0), non-linear
Dependency 2, f_6 (id 6), non-linear
Storage:
Storing initial conditions: yes
Storing equation non-linear dependencies: yes
Initial conditions stored: 2
Initial conditions referenced: 0
Checkpointing:
Method: memory
The key feature here is that there are four FloatEquation records, corresponding to the four floating point calculations: evaluation using np.exp and np.sin, and two multiplications.
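To make the idea of recording concrete, here is a toy, hand-written "tape" that records each elementary operation together with the values it depends on, loosely mirroring the four FloatEquation records above. This is an illustrative sketch only; the Tape class and its record format are hypothetical and are not how tlm_adjoint stores equations.

```python
import math


class Tape:
    """A toy record of elementary operations (hypothetical, for illustration only)."""

    def __init__(self):
        self.records = []  # entries of the form (output name, operation, input names)
        self.values = {}   # variable name -> float value
        self._count = 0

    def input(self, name, value):
        self.values[name] = value
        return name

    def _record(self, op, inputs, value):
        name = f"f_{self._count}"
        self._count += 1
        self.values[name] = value
        self.records.append((name, op, inputs))
        return name

    def exp(self, a):
        return self._record("exp", (a,), math.exp(self.values[a]))

    def sin(self, a):
        return self._record("sin", (a,), math.sin(self.values[a]))

    def mul(self, a, b):
        return self._record("mul", (a, b), self.values[a] * self.values[b])


tape = Tape()
x = tape.input("x", 1.0)
y = tape.input("y", 0.25 * math.pi)
z = tape.mul(x, tape.sin(tape.mul(y, tape.exp(y))))

# One record per elementary operation: exp, multiplication, sin, multiplication
for output, op, inputs in tape.records:
    print(output, op, inputs)
```

As in the manager output, the second multiplication depends on x and on the result of the sin, and the first depends on y twice: directly, and through the exp.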
Computing derivatives using an adjoint
The compute_gradient function can be used to differentiate z with respect to x and y using the adjoint method:
[4]:
dz_dx, dz_dy = compute_gradient(z, (x, y))
print(f"{float(dz_dx)=}")
print(f"{float(dz_dy)=}")
float(dz_dx)=0.9885002159138745
float(dz_dy)=-0.592156957782821
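The adjoint method works backwards through the recorded operations, accumulating the sensitivity of z to each intermediate value. The following sketch does this by hand for the four operations of this example, in plain Python; it illustrates the technique and is not tlm_adjoint's implementation. The intermediate names follow the manager_info output above.

```python
import math

# Forward calculation, keeping the intermediate values
x = 1.0
y = 0.25 * math.pi
f_2 = math.exp(y)
f_4 = y * f_2
f_6 = math.sin(f_4)
z = x * f_6

# Adjoint (reverse) sweep: each adj_v holds dz/dv, starting from dz/dz = 1
adj_z = 1.0
# z = x * f_6
adj_x = adj_z * f_6
adj_f_6 = adj_z * x
# f_6 = sin(f_4)
adj_f_4 = adj_f_6 * math.cos(f_4)
# f_4 = y * f_2 (y is also used by the exp below, so contributions are summed)
adj_y = adj_f_4 * f_2
adj_f_2 = adj_f_4 * y
# f_2 = exp(y), with d(exp(y))/dy = exp(y) = f_2
adj_y += adj_f_2 * f_2

print(f"{adj_x=}")
print(f"{adj_y=}")
```

A single reverse sweep yields the derivative with respect to both x and y, which is why the adjoint method is efficient when there are many controls and one output.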
For a simple check of the result, we can compare with the result from finite differencing, here using second order centered finite differencing:
[5]:
def dJ_dm(J, m, *, eps=1.0e-7):
    return (J(m + eps) - J(m - eps)) / (2.0 * eps)
print(f"dz_dx approximation = {dJ_dm(lambda x: forward(x, float(y)), float(x))}")
print(f"dz_dy approximation = {dJ_dm(lambda y: forward(float(x), y), float(y))}")
dz_dx approximation = 0.9885002155707312
dz_dy approximation = -0.5921569579125929
Computing derivatives using a tangent-linear
A tangent-linear computes directional derivatives with respect to a given control and with a given direction.
Here we consider the forward to be a function of x and y, computing a value of z, denoted \(z \left( x, y \right)\). We consider the control \(m = \left( x, y \right)^T\), and a direction \(\zeta = \left( 2, 3 \right)^T\). We can then use a tangent-linear to compute the directional derivative

\[ \frac{dz}{dm} \zeta = 2 \frac{\partial z}{\partial x} + 3 \frac{\partial z}{\partial y}, \]

where vector derivatives are notated using row vectors.
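The directional derivative can be approximated in the same way as the earlier finite difference check, by perturbing both inputs along \(\zeta\) at once: \(\left[ z(m + \epsilon \zeta) - z(m - \epsilon \zeta) \right] / (2 \epsilon)\). A minimal sketch in plain Python, where z_fn and directional_fd are illustrative names:

```python
import math


def z_fn(x, y):
    return x * math.sin(y * math.exp(y))


def directional_fd(f, m, zeta, *, eps=1.0e-7):
    """Second order centered finite difference approximation of (df/dm) zeta."""
    m_plus = [m_i + eps * zeta_i for m_i, zeta_i in zip(m, zeta)]
    m_minus = [m_i - eps * zeta_i for m_i, zeta_i in zip(m, zeta)]
    return (f(*m_plus) - f(*m_minus)) / (2.0 * eps)


m = (1.0, 0.25 * math.pi)
zeta = (2.0, 3.0)
print(directional_fd(z_fn, m, zeta))
```

Only two forward evaluations are needed regardless of the number of controls, which mirrors the fact that a tangent-linear computes one directional derivative per direction.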
In most cases tlm_adjoint needs to be told what tangent-linear calculations to perform ahead of the forward calculations. configure_tlm provides this information to tlm_adjoint:
[6]:
from tlm_adjoint import *
import numpy as np
reset_manager()
def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z
x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")
m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
configure_tlm((m, zeta))
start_manager()
z = forward(x, y)
stop_manager()
dz_dm_zeta = var_tlm(z, (m, zeta))
print(f"{float(dz_dm_zeta)=}")
float(dz_dm_zeta)=0.2005295584792861
There are three new changes:
- The control \(m\) and direction \(\zeta\) are defined using m and zeta respectively.
- Tangent-linear calculations are configured before the forward calculations are performed, using configure_tlm. Note that here, for this first derivative calculation, the argument is a single control-direction pair.
- We access the tangent-linear variable, containing the value of \(\left( dz/dm \right) \zeta\), using var_tlm with the same control-direction pair.
In fact more has happened here – if we now display the information recorded by tlm_adjoint:
[7]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
Block 0
Equation 0, FloatEquation solving for f_31 (id 31)
Dependency 0, f_31 (id 31), linear
Dependency 1, y (id 28), non-linear
Equation 1, FloatEquation solving for f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33)
Dependency 0, f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33), linear
Dependency 1, y (id 28), non-linear
Dependency 2, zeta_y (id 30), non-linear
Equation 2, FloatEquation solving for f_35 (id 35)
Dependency 0, f_35 (id 35), linear
Dependency 1, y (id 28), non-linear
Dependency 2, f_31 (id 31), non-linear
Equation 3, FloatEquation solving for f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37)
Dependency 0, f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37), linear
Dependency 1, y (id 28), non-linear
Dependency 2, zeta_y (id 30), non-linear
Dependency 3, f_31 (id 31), non-linear
Dependency 4, f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33), non-linear
Equation 4, FloatEquation solving for f_39 (id 39)
Dependency 0, f_39 (id 39), linear
Dependency 1, f_35 (id 35), non-linear
Equation 5, FloatEquation solving for f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41)
Dependency 0, f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41), linear
Dependency 1, f_35 (id 35), non-linear
Dependency 2, f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37), non-linear
Equation 6, FloatEquation solving for f_43 (id 43)
Dependency 0, f_43 (id 43), linear
Dependency 1, x (id 27), non-linear
Dependency 2, f_39 (id 39), non-linear
Equation 7, FloatEquation solving for f_43_tlm((x,y),(zeta_x,zeta_y)) (id 46)
Dependency 0, f_43_tlm((x,y),(zeta_x,zeta_y)) (id 46), linear
Dependency 1, x (id 27), non-linear
Dependency 2, zeta_x (id 29), non-linear
Dependency 3, f_39 (id 39), non-linear
Dependency 4, f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41), non-linear
Storage:
Storing initial conditions: yes
Storing equation non-linear dependencies: yes
Initial conditions stored: 4
Initial conditions referenced: 0
Checkpointing:
Method: memory
we now see that there are eight FloatEquation records: the original four, and four new ones. The extra ones correspond to the tangent-linear calculations. tlm_adjoint has recorded the original forward calculations, and also recorded the tangent-linear calculations.
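The recorded structure can be reproduced by hand: each forward operation is followed by a tangent-linear operation that propagates the direction using the chain and product rules. In the sketch below, written in plain Python as an illustration rather than as tlm_adjoint's implementation, a _tlm suffix marks a tangent-linear value, the variable names follow the manager output above, and the initial tangent-linear values of x and y are the components of \(\zeta\).

```python
import math

x, y = 1.0, 0.25 * math.pi
x_tlm, y_tlm = 2.0, 3.0  # the direction zeta

f_31 = math.exp(y)
f_31_tlm = math.exp(y) * y_tlm          # depends on y, zeta_y
f_35 = y * f_31
f_35_tlm = y_tlm * f_31 + y * f_31_tlm  # depends on y, zeta_y, f_31, f_31_tlm
f_39 = math.sin(f_35)
f_39_tlm = math.cos(f_35) * f_35_tlm    # depends on f_35, f_35_tlm
f_43 = x * f_39
f_43_tlm = x_tlm * f_39 + x * f_39_tlm  # depends on x, zeta_x, f_39, f_39_tlm

print(f"{f_43_tlm=}")
```

The dependencies noted in the comments match those listed for the corresponding tangent-linear equations in the manager output.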
Computing second derivatives using an adjoint of a tangent-linear
Since tlm_adjoint has recorded both forward and tangent-linear calculations, we can now compute second derivative information using an adjoint associated with a tangent-linear. Specifically, we can compute a Hessian action on \(\zeta\),

\[ H \zeta = \left( \frac{d}{dm} \left[ \frac{dz}{dm} \zeta \right] \right)^T, \]

where \(H\) is the Hessian of \(z\) with respect to \(m\) and \(\zeta\) is held fixed. The inner directional derivative appearing here is computed using the tangent-linear method, and the outer derivative is computed by applying the adjoint method to the tangent-linear and forward calculations. In code we simply use compute_gradient to compute the derivative of the tangent-linear result:
[8]:
from tlm_adjoint import *
import numpy as np
reset_manager()
def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z
x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")
m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
configure_tlm((m, zeta))
start_manager()
z = forward(x, y)
stop_manager()
dz_dm_zeta = var_tlm(z, (m, zeta))
print(f"{float(dz_dm_zeta)=}")
d2z_dm_zeta_dx, d2z_dm_zeta_dy = compute_gradient(dz_dm_zeta, m)
print(f"{float(d2z_dm_zeta_dx)=}")
print(f"{float(d2z_dm_zeta_dy)=}")
float(dz_dm_zeta)=0.2005295584792861
float(d2z_dm_zeta_dx)=-1.7764708733484629
float(d2z_dm_zeta_dy)=-49.4290736694207
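These values can be checked against a Hessian derived by hand. With \(u = y e^y\), \(u' = (1 + y) e^y\) and \(u'' = (2 + y) e^y\), the second partial derivatives of \(z = x \sin u\) are \(\partial^2 z / \partial x^2 = 0\), \(\partial^2 z / \partial x \partial y = \cos \left( u \right) u'\), and \(\partial^2 z / \partial y^2 = x \left[ \cos \left( u \right) u'' - \sin \left( u \right) (u')^2 \right]\). A plain Python sketch of the Hessian action, where hessian_action is an illustrative name:

```python
import math


def hessian_action(x, y, zeta):
    """Hand-derived Hessian of z = x * sin(y * exp(y)), applied to zeta."""
    u = y * math.exp(y)
    du = (1.0 + y) * math.exp(y)   # u'
    d2u = (2.0 + y) * math.exp(y)  # u''
    z_xx = 0.0                     # z is linear in x
    z_xy = math.cos(u) * du
    z_yy = x * (math.cos(u) * d2u - math.sin(u) * du ** 2)
    return (z_xx * zeta[0] + z_xy * zeta[1],
            z_xy * zeta[0] + z_yy * zeta[1])


print(hessian_action(1.0, 0.25 * math.pi, (2.0, 3.0)))
```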
Computing second derivatives using a tangent-linear of a tangent-linear
We can also compute second derivative information using a tangent-linear associated with a tangent-linear. For example, if we define \(e_1 = \left( 1, 0 \right)^T\), then we can find the first component of the previously computed Hessian action on \(\zeta\) via

\[ \frac{d}{dm} \left[ \frac{dz}{dm} \zeta \right] e_1. \]
That is, here we now want to compute a directional derivative of a directional derivative, and we compute this using a tangent-linear associated with the previous tangent-linear and forward calculations.
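One way to see how a tangent-linear of a tangent-linear works is with nested dual numbers, where the value and tangent of a dual number are themselves dual numbers. In the sketch below the outer level carries the direction \(\zeta\), the inner level carries \(e_1\), and the \(e_1\) tangent of each component of \(\zeta\) is zero because the direction is fixed. This is a minimal plain Python illustration of nested forward mode differentiation, not tlm_adjoint's implementation; the Dual class and the exp, sin and cos helpers are hypothetical.

```python
import math


class Dual:
    """A dual number val + tan * eps with eps**2 == 0; val and tan may be Duals."""

    def __init__(self, val, tan):
        self.val = val
        self.tan = tan

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.val + other.val, self.tan + other.tan)
        return Dual(self.val + other, self.tan)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule
            return Dual(self.val * other.val,
                        self.tan * other.val + self.val * other.tan)
        return Dual(self.val * other, self.tan * other)

    __rmul__ = __mul__  # multiplication is commutative

    def __neg__(self):
        return Dual(-self.val, -self.tan)


def exp(a):
    if isinstance(a, Dual):
        return Dual(exp(a.val), exp(a.val) * a.tan)
    return math.exp(a)


def sin(a):
    if isinstance(a, Dual):
        return Dual(sin(a.val), cos(a.val) * a.tan)
    return math.sin(a)


def cos(a):
    if isinstance(a, Dual):
        return Dual(cos(a.val), -sin(a.val) * a.tan)
    return math.cos(a)


# Outer tangents: components of zeta. Inner tangents: components of e_1.
x = Dual(Dual(1.0, 1.0), Dual(2.0, 0.0))
y = Dual(Dual(0.25 * math.pi, 0.0), Dual(3.0, 0.0))
z = x * sin(y * exp(y))

print(f"{z.tan.val=}")  # (dz/dm) zeta
print(f"{z.val.tan=}")  # (dz/dm) e_1, i.e. dz/dx
print(f"{z.tan.tan=}")  # the e_1 directional derivative of (dz/dm) zeta
```

As with configure_tlm given two control-direction pairs, the nested calculation produces the \(e_1\) directional derivative of the forward as a side-effect, in z.val.tan.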
In tlm_adjoint we handle this case by supplying more arguments to configure_tlm:
[9]:
from tlm_adjoint import *
import numpy as np
reset_manager()
def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z
x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")
m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
e_1 = (Float(1.0, name="e_1_x"), Float(0.0, name="e_1_y"))
configure_tlm((m, zeta), (m, e_1))
start_manager()
z = forward(x, y)
stop_manager()
dz_dm_zeta = var_tlm(z, (m, zeta))
dz_dx = var_tlm(z, (m, e_1))
d2z_dm_zeta_dx = var_tlm(z, (m, zeta), (m, e_1))
print(f"{float(dz_dm_zeta)=}")
print(f"{float(dz_dx)=}")
print(f"{float(d2z_dm_zeta_dx)=}")
float(dz_dm_zeta)=0.2005295584792861
float(dz_dx)=0.9885002159138745
float(d2z_dm_zeta_dx)=-1.7764708733484629
The first control-direction pair supplied to configure_tlm indicates that we seek to compute directional derivatives of the forward with respect to the control defined by m with direction defined by zeta. The second control-direction pair indicates that we seek to compute directional derivatives of these directional derivatives, with respect to the control defined by m and with direction defined by e_1. As a side-effect, we also compute the directional derivatives of the forward with respect to the control defined by m with direction defined by e_1.
We then access the tangent-linear variables using var_tlm, supplying two control-direction pairs to access a second order tangent-linear variable.
As before, tlm_adjoint has not just performed the tangent-linear calculations – if we display the information recorded by tlm_adjoint:
[10]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
Block 0
Equation 0, FloatEquation solving for f_103 (id 103)
Dependency 0, f_103 (id 103), linear
Dependency 1, y (id 98), non-linear
Equation 1, FloatEquation solving for f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105)
Dependency 0, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), linear
Dependency 1, y (id 98), non-linear
Dependency 2, zeta_y (id 100), non-linear
Equation 2, FloatEquation solving for f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107)
Dependency 0, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), linear
Dependency 1, y (id 98), non-linear
Dependency 2, e_1_y (id 102), non-linear
Equation 3, FloatEquation solving for f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110)
Dependency 0, f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110), linear
Dependency 1, y (id 98), non-linear
Dependency 2, zeta_y (id 100), non-linear
Dependency 3, e_1_y (id 102), non-linear
Dependency 4, zeta_y_tlm((x,y),(e_1_x,e_1_y)) (id 109), non-linear
Equation 4, FloatEquation solving for f_112 (id 112)
Dependency 0, f_112 (id 112), linear
Dependency 1, y (id 98), non-linear
Dependency 2, f_103 (id 103), non-linear
Equation 5, FloatEquation solving for f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114)
Dependency 0, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), linear
Dependency 1, y (id 98), non-linear
Dependency 2, zeta_y (id 100), non-linear
Dependency 3, f_103 (id 103), non-linear
Dependency 4, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), non-linear
Equation 6, FloatEquation solving for f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116)
Dependency 0, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), linear
Dependency 1, y (id 98), non-linear
Dependency 2, e_1_y (id 102), non-linear
Dependency 3, f_103 (id 103), non-linear
Dependency 4, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), non-linear
Equation 7, FloatEquation solving for f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118)
Dependency 0, f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118), linear
Dependency 1, y (id 98), non-linear
Dependency 2, zeta_y (id 100), non-linear
Dependency 3, e_1_y (id 102), non-linear
Dependency 4, f_103 (id 103), non-linear
Dependency 5, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), non-linear
Dependency 6, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), non-linear
Dependency 7, zeta_y_tlm((x,y),(e_1_x,e_1_y)) (id 109), non-linear
Dependency 8, f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110), non-linear
Equation 8, FloatEquation solving for f_120 (id 120)
Dependency 0, f_120 (id 120), linear
Dependency 1, f_112 (id 112), non-linear
Equation 9, FloatEquation solving for f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122)
Dependency 0, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), linear
Dependency 1, f_112 (id 112), non-linear
Dependency 2, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), non-linear
Equation 10, FloatEquation solving for f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124)
Dependency 0, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), linear
Dependency 1, f_112 (id 112), non-linear
Dependency 2, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), non-linear
Equation 11, FloatEquation solving for f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126)
Dependency 0, f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126), linear
Dependency 1, f_112 (id 112), non-linear
Dependency 2, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), non-linear
Dependency 3, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), non-linear
Dependency 4, f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118), non-linear
Equation 12, FloatEquation solving for f_128 (id 128)
Dependency 0, f_128 (id 128), linear
Dependency 1, x (id 97), non-linear
Dependency 2, f_120 (id 120), non-linear
Equation 13, FloatEquation solving for f_128_tlm((x,y),(zeta_x,zeta_y)) (id 131)
Dependency 0, f_128_tlm((x,y),(zeta_x,zeta_y)) (id 131), linear
Dependency 1, x (id 97), non-linear
Dependency 2, zeta_x (id 99), non-linear
Dependency 3, f_120 (id 120), non-linear
Dependency 4, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), non-linear
Equation 14, FloatEquation solving for f_128_tlm((x,y),(e_1_x,e_1_y)) (id 134)
Dependency 0, f_128_tlm((x,y),(e_1_x,e_1_y)) (id 134), linear
Dependency 1, x (id 97), non-linear
Dependency 2, e_1_x (id 101), non-linear
Dependency 3, f_120 (id 120), non-linear
Dependency 4, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), non-linear
Equation 15, FloatEquation solving for f_128_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 138)
Dependency 0, f_128_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 138), linear
Dependency 1, x (id 97), non-linear
Dependency 2, zeta_x (id 99), non-linear
Dependency 3, e_1_x (id 101), non-linear
Dependency 4, f_120 (id 120), non-linear
Dependency 5, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), non-linear
Dependency 6, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), non-linear
Dependency 7, f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126), non-linear
Dependency 8, zeta_x_tlm((x,y),(e_1_x,e_1_y)) (id 137), non-linear
Storage:
Storing initial conditions: yes
Storing equation non-linear dependencies: yes
Initial conditions stored: 8
Initial conditions referenced: 0
Checkpointing:
Method: memory
we now find that there are sixteen FloatEquation records, constituting the forward and all tangent-linear calculations.
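The counts seen so far (four equations for the forward alone, eight with one control-direction pair, sixteen with two) appear consistent with each forward equation being accompanied by one tangent-linear equation for every non-empty subset of the configured control-direction pairs, applied in the configured order. The plain Python sketch below enumerates these combinations for a single forward variable, using a naming style loosely based on the manager output; tlm_names is a hypothetical helper, and this reproduces the pattern observed above rather than describing tlm_adjoint internals.

```python
from itertools import combinations


def tlm_names(name, directions):
    """List a variable and its tangent-linear variables, one per subset of directions."""
    names = []
    for order in range(len(directions) + 1):
        for subset in combinations(directions, order):
            names.append(name + "".join(f"_tlm({d})" for d in subset))
    return names


print(tlm_names("f_103", ["zeta", "e_1"]))
# ['f_103', 'f_103_tlm(zeta)', 'f_103_tlm(e_1)', 'f_103_tlm(zeta)_tlm(e_1)']
```

Under this pattern the number of equations doubles with each additional control-direction pair, so high order tangent-linear calculations become correspondingly more expensive.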
Computing third derivatives using an adjoint of a tangent-linear of a tangent-linear
We can now compute the derivative of the tangent-linear-computed second derivative by simply handing the second order tangent-linear variable to compute_gradient. This applies the adjoint method to the higher order tangent-linear calculations and the forward calculations, computing

\[ \left( \frac{d}{dm} \left[ \frac{d}{dm} \left[ \frac{dz}{dm} \zeta \right] e_1 \right] \right)^T. \]
[11]:
from tlm_adjoint import *
import numpy as np
reset_manager()
def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z
x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")
m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
e_1 = (Float(1.0, name="e_1_x"), Float(0.0, name="e_1_y"))
configure_tlm((m, zeta), (m, e_1))
start_manager()
z = forward(x, y)
stop_manager()
dz_dm_zeta = var_tlm(z, (m, zeta))
dz_dx = var_tlm(z, (m, e_1))
d2z_dm_zeta_dx = var_tlm(z, (m, zeta), (m, e_1))
print(f"{float(dz_dm_zeta)=}")
print(f"{float(dz_dx)=}")
print(f"{float(d2z_dm_zeta_dx)=}")
d3z_dm_zeta_dx_dx, d3z_dm_zeta_dx_dy = compute_gradient(d2z_dm_zeta_dx, m)
print(f"{float(d3z_dm_zeta_dx_dx)=}")
print(f"{float(d3z_dm_zeta_dx_dy)=}")
float(dz_dm_zeta)=0.2005295584792861
float(dz_dx)=0.9885002159138745
float(d2z_dm_zeta_dx)=-1.7764708733484629
float(d3z_dm_zeta_dx_dx)=0.0
float(d3z_dm_zeta_dx_dy)=-48.244759753855064
Higher order
The approach now generalizes:
- Supplying further arguments to configure_tlm indicates directional derivatives of directional derivatives, defining a tangent-linear calculation of increasingly high order.
- Supplying these arguments to var_tlm accesses the higher order tangent-linear variables.
- These higher order tangent-linear variables can be handed to compute_gradient to compute derivatives of the higher order derivatives using the adjoint method.
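A simple, if expensive and noisy, way to sanity-check higher order results is to nest the centered finite differences used earlier, one level per direction. The sketch below, in plain Python with the illustrative names z_of_m and nested_fd, approximates the third directional derivative of z along \(\zeta\), \(e_1\) and \(e_2 = \left( 0, 1 \right)^T\), which should agree with d3z_dm_zeta_dx_dy computed above up to finite difference error. A larger step is used than in the first order check, since the rounding error of a \(K\)-th order nested difference grows roughly like \(\epsilon^{-K}\).

```python
import math


def z_of_m(m):
    x, y = m
    return x * math.sin(y * math.exp(y))


def nested_fd(f, m, directions, *, eps=1.0e-3):
    """Approximate a higher order directional derivative by nesting centered differences."""
    if not directions:
        return f(m)
    d, rest = directions[0], directions[1:]
    m_plus = tuple(m_i + eps * d_i for m_i, d_i in zip(m, d))
    m_minus = tuple(m_i - eps * d_i for m_i, d_i in zip(m, d))
    return (nested_fd(f, m_plus, rest, eps=eps)
            - nested_fd(f, m_minus, rest, eps=eps)) / (2.0 * eps)


m = (1.0, 0.25 * math.pi)
zeta, e_1, e_2 = (2.0, 3.0), (1.0, 0.0), (0.0, 1.0)
print(nested_fd(z_of_m, m, [zeta, e_1, e_2]))
```

Each extra direction doubles the number of forward evaluations (\(2^K\) in total), so this is only practical as a low order check on small problems.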