Getting started with tlm_adjoint

This notebook introduces derivative calculations using tlm_adjoint.

tlm_adjoint is primarily intended for first derivative calculations using the adjoint method, and for Hessian information calculations using the adjoint method applied to tangent-linear and forward calculations. However, the approach used by tlm_adjoint generalizes to higher order.

The approach used by tlm_adjoint for higher order differentiation is described in:

  • James R. Maddison, Daniel N. Goldberg, and Benjamin D. Goddard, ‘Automated calculation of higher order partial differential equation constrained derivative information’, SIAM Journal on Scientific Computing, 41(5), pp. C417–C445, 2019, doi: 10.1137/18M1209465

A floating point example

We consider a simple floating point calculation:

[1]:
import numpy as np

x = 1.0
y = 0.25 * np.pi
z = x * np.sin(y * np.exp(y))

tlm_adjoint is designed for high-level algorithmic differentiation, not for this type of low-level floating point calculation. However, it can still process simple floating point calculations, and we use one here to introduce the key ideas. We consider differentiating z with respect to x and y. To be precise, we mean computing the derivative of the function used to compute z with respect to the variables defined by x and y.
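For reference, these derivatives can also be worked out by hand. Writing \(g \left( y \right) = y e^y\) (a shorthand introduced here, not used by tlm_adjoint), so that \(z = x \sin g \left( y \right)\), the chain rule gives

\[\frac{dz}{dx} = \sin g \left( y \right), \qquad \frac{dz}{dy} = x \cos g \left( y \right) g' \left( y \right), \qquad g' \left( y \right) = \left( 1 + y \right) e^y.\]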

Adding tlm_adjoint

We first modify the code so that tlm_adjoint processes the calculations:

[2]:
from tlm_adjoint import *

import numpy as np

reset_manager()


def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z


x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")

start_manager()
z = forward(x, y)
stop_manager()
[2]:
(True, True)

The key changes here are:

  • To import tlm_adjoint.

  • Controlling the ‘manager’ – an object tlm_adjoint uses to process equations. The manager is first reset using reset_manager. This clears any previous processing, and disables the manager. start_manager and stop_manager are then used to enable the manager just when it is needed.

  • Defining x and y to be of type Float. Calculations involving x and y are recorded by the manager. The result of the calculations – here z – will have the same type, and we can access its value with float(z).

Let’s display the information recorded by tlm_adjoint:

[3]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
  Block 0
    Equation 0, FloatEquation solving for f_2 (id 2)
      Dependency 0, f_2 (id 2), linear
      Dependency 1, y (id 1), non-linear
    Equation 1, FloatEquation solving for f_4 (id 4)
      Dependency 0, f_4 (id 4), linear
      Dependency 1, y (id 1), non-linear
      Dependency 2, f_2 (id 2), non-linear
    Equation 2, FloatEquation solving for f_6 (id 6)
      Dependency 0, f_6 (id 6), linear
      Dependency 1, f_4 (id 4), non-linear
    Equation 3, FloatEquation solving for f_8 (id 8)
      Dependency 0, f_8 (id 8), linear
      Dependency 1, x (id 0), non-linear
      Dependency 2, f_6 (id 6), non-linear
Storage:
  Storing initial conditions: yes
  Storing equation non-linear dependencies: yes
  Initial conditions stored: 2
  Initial conditions referenced: 0
Checkpointing:
  Method: memory

The key feature here is that there are four FloatEquation records, corresponding to the four floating point calculations – evaluation using np.exp and np.sin, and two multiplications.
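The recording idea can be illustrated with a toy sketch in plain Python. The class RecordedFloat, the helpers rec_exp and rec_sin, and the tape list are hypothetical names introduced here; this is not how tlm_adjoint's Float or manager is implemented, only an illustration of why one record appears per elementary calculation.

```python
import math

# Hypothetical minimal "tape": each entry records one elementary calculation
# as (name of the output, names of the inputs it depends on).
tape = []


class RecordedFloat:
    """Hypothetical stand-in for a recorded scalar; not tlm_adjoint's Float."""

    _count = 0  # used to generate names for intermediate results

    def __init__(self, value, name=None):
        self.value = value
        if name is None:
            name = f"f_{RecordedFloat._count}"
            RecordedFloat._count += 1
        self.name = name

    def __mul__(self, other):
        # Only RecordedFloat * RecordedFloat is supported in this sketch
        result = RecordedFloat(self.value * other.value)
        tape.append((result.name, (self.name, other.name)))
        return result


def rec_exp(a):
    result = RecordedFloat(math.exp(a.value))
    tape.append((result.name, (a.name,)))
    return result


def rec_sin(a):
    result = RecordedFloat(math.sin(a.value))
    tape.append((result.name, (a.name,)))
    return result


x = RecordedFloat(1.0, name="x")
y = RecordedFloat(0.25 * math.pi, name="y")
z = x * rec_sin(y * rec_exp(y))

for output, inputs in tape:
    print(output, "<-", inputs)
```

It prints four records whose dependency structure mirrors Equations 0 to 3 in the listing above.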

Computing derivatives using an adjoint

The compute_gradient function can be used to differentiate z with respect to x and y using the adjoint method:

[4]:
dz_dx, dz_dy = compute_gradient(z, (x, y))

print(f"{float(dz_dx)=}")
print(f"{float(dz_dy)=}")
float(dz_dx)=0.9885002159138745
float(dz_dy)=-0.592156957782821
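To make the adjoint method concrete, the following sketch hand-codes a forward sweep and a reverse (adjoint) sweep for this particular calculation in plain Python, with math in place of numpy. It illustrates the technique only: the intermediate names f_a, f_b and f_c are chosen here, and in tlm_adjoint the adjoint calculations are derived from the recorded equations, so none of this needs to be written by hand.

```python
import math

x, y = 1.0, 0.25 * math.pi

# Forward sweep: evaluate and store each intermediate value, mirroring the
# four recorded equations (exp, multiply, sin, multiply)
f_a = math.exp(y)    # f_a = exp(y)
f_b = y * f_a        # f_b = y * f_a
f_c = math.sin(f_b)  # f_c = sin(f_b)
z = x * f_c          # z = x * f_c

# Reverse (adjoint) sweep: start from dz/dz = 1 and visit the equations in
# reverse order, accumulating the adjoint of each dependency
z_bar = 1.0
x_bar = z_bar * f_c                # from z = x * f_c
f_c_bar = z_bar * x
f_b_bar = f_c_bar * math.cos(f_b)  # from f_c = sin(f_b)
y_bar = f_b_bar * f_a              # from f_b = y * f_a, y as a direct factor
f_a_bar = f_b_bar * y
y_bar += f_a_bar * f_a             # from f_a = exp(y), since d(exp(y))/dy = exp(y)

print(f"{x_bar=}")
print(f"{y_bar=}")
```

A single reverse sweep yields both derivatives at once, which is why the adjoint method is attractive when there are many controls and one output.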

As a simple check, we can compare with a finite difference approximation, here using second order centered differencing:

[5]:
def dJ_dm(J, m, *, eps=1.0e-7):
    return (J(m + eps) - J(m - eps)) / (2.0 * eps)


print(f"dz_dx approximation = {dJ_dm(lambda x: forward(x, float(y)), float(x))}")
print(f"dz_dy approximation = {dJ_dm(lambda y: forward(float(x), y), float(y))}")
dz_dx approximation = 0.9885002155707312
dz_dy approximation = -0.5921569579125929
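Because the centered difference is second order accurate, halving eps should reduce its error by roughly a factor of four, provided eps is large enough that rounding error is negligible. A minimal sketch of this check for dz/dy, in plain Python with math in place of numpy and the analytic derivative written out by hand (the eps values are arbitrary choices):

```python
import math


def forward(x, y):
    return x * math.sin(y * math.exp(y))


def dJ_dm(J, m, *, eps=1.0e-7):
    # Second order centered finite difference, as in the cell above
    return (J(m + eps) - J(m - eps)) / (2.0 * eps)


x, y = 1.0, 0.25 * math.pi
# Analytic dz/dy = x cos(y e^y) (1 + y) e^y
exact = x * math.cos(y * math.exp(y)) * (1.0 + y) * math.exp(y)

errors = []
for eps in (1.0e-2, 5.0e-3, 2.5e-3):
    approx = dJ_dm(lambda y_: forward(x, y_), y, eps=eps)
    errors.append(abs(approx - exact))

# Ratio of successive errors; close to 4 for a second order method
ratios = [errors[i] / errors[i + 1] for i in range(len(errors) - 1)]
print(errors)
print(ratios)
```

The derivative with respect to x is not a useful test case here: z is linear in x, so the centered difference has no truncation error in that direction.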

Computing derivatives using a tangent-linear

A tangent-linear computes directional derivatives with respect to a given control and with a given direction.

Here we consider the forward to be a function of x and y, computing a value of z, denoted \(z \left( x, y \right)\). We consider the control \(m = \left( x, y \right)^T\), and a direction \(\zeta = \left( 2, 3 \right)^T\). We can then use a tangent-linear to compute the directional derivative

\[\frac{dz}{dm} \zeta = 2 \frac{dz}{dx} + 3 \frac{dz}{dy},\]

where vector derivatives are notated using row vectors.
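A directional derivative can be checked in the same spirit by differencing along the direction itself, \(\left( z \left( m + \epsilon \zeta \right) - z \left( m - \epsilon \zeta \right) \right) / \left( 2 \epsilon \right)\). A plain-Python sketch (math in place of numpy; eps is an arbitrary choice), compared against \(2 \, dz/dx + 3 \, dz/dy\) assembled from the analytic partial derivatives:

```python
import math


def forward(x, y):
    return x * math.sin(y * math.exp(y))


x, y = 1.0, 0.25 * math.pi
zeta_x, zeta_y = 2.0, 3.0
eps = 1.0e-7

# Centered difference along the direction zeta approximates (dz/dm) zeta
approx = (forward(x + eps * zeta_x, y + eps * zeta_y)
          - forward(x - eps * zeta_x, y - eps * zeta_y)) / (2.0 * eps)

# The same quantity from the partial derivatives: 2 dz/dx + 3 dz/dy
exact = (zeta_x * math.sin(y * math.exp(y))
         + zeta_y * x * math.cos(y * math.exp(y)) * (1.0 + y) * math.exp(y))

print(f"{approx=}")
print(f"{exact=}")
```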

In most cases tlm_adjoint needs to be told what tangent-linear calculations to perform ahead of the forward calculations. configure_tlm provides this information to tlm_adjoint:

[6]:
from tlm_adjoint import *

import numpy as np

reset_manager()


def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z


x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")

m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
configure_tlm((m, zeta))

start_manager()
z = forward(x, y)
stop_manager()

dz_dm_zeta = var_tlm(z, (m, zeta))

print(f"{float(dz_dm_zeta)=}")
float(dz_dm_zeta)=0.2005295584792861

There are three new changes:

  • The control \(m\) and direction \(\zeta\) are defined using m and zeta respectively.

  • Tangent-linear calculations are configured before the forward calculations are performed, using configure_tlm. Note that here, for this first derivative calculation, the argument is a single control-direction pair.

  • We access the tangent-linear variable, which contains the value of \(\left( dz/dm \right) \zeta\), using var_tlm with the same control-direction pair.

In fact more has happened here – if we now display the information recorded by tlm_adjoint:

[7]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
  Block 0
    Equation 0, FloatEquation solving for f_31 (id 31)
      Dependency 0, f_31 (id 31), linear
      Dependency 1, y (id 28), non-linear
    Equation 1, FloatEquation solving for f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33)
      Dependency 0, f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33), linear
      Dependency 1, y (id 28), non-linear
      Dependency 2, zeta_y (id 30), non-linear
    Equation 2, FloatEquation solving for f_35 (id 35)
      Dependency 0, f_35 (id 35), linear
      Dependency 1, y (id 28), non-linear
      Dependency 2, f_31 (id 31), non-linear
    Equation 3, FloatEquation solving for f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37)
      Dependency 0, f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37), linear
      Dependency 1, y (id 28), non-linear
      Dependency 2, zeta_y (id 30), non-linear
      Dependency 3, f_31 (id 31), non-linear
      Dependency 4, f_31_tlm((x,y),(zeta_x,zeta_y)) (id 33), non-linear
    Equation 4, FloatEquation solving for f_39 (id 39)
      Dependency 0, f_39 (id 39), linear
      Dependency 1, f_35 (id 35), non-linear
    Equation 5, FloatEquation solving for f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41)
      Dependency 0, f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41), linear
      Dependency 1, f_35 (id 35), non-linear
      Dependency 2, f_35_tlm((x,y),(zeta_x,zeta_y)) (id 37), non-linear
    Equation 6, FloatEquation solving for f_43 (id 43)
      Dependency 0, f_43 (id 43), linear
      Dependency 1, x (id 27), non-linear
      Dependency 2, f_39 (id 39), non-linear
    Equation 7, FloatEquation solving for f_43_tlm((x,y),(zeta_x,zeta_y)) (id 46)
      Dependency 0, f_43_tlm((x,y),(zeta_x,zeta_y)) (id 46), linear
      Dependency 1, x (id 27), non-linear
      Dependency 2, zeta_x (id 29), non-linear
      Dependency 3, f_39 (id 39), non-linear
      Dependency 4, f_39_tlm((x,y),(zeta_x,zeta_y)) (id 41), non-linear
Storage:
  Storing initial conditions: yes
  Storing equation non-linear dependencies: yes
  Initial conditions stored: 4
  Initial conditions referenced: 0
Checkpointing:
  Method: memory

we now see that there are eight FloatEquation records: the original four, and four new ones corresponding to the tangent-linear calculations. tlm_adjoint has recorded both the original forward calculations and the tangent-linear calculations.
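The eight records can be read as each forward equation followed by its tangent-linear equation. The sketch below writes these out by hand in plain Python; the names f_a, f_b, f_c and the _t suffix for tangent-linear variables are chosen here, but the dependencies noted in the comments match those in the listing above.

```python
import math

x, y = 1.0, 0.25 * math.pi
zeta_x, zeta_y = 2.0, 3.0

# Each forward equation is followed by its tangent-linear equation, which
# computes the directional derivative ("_t" suffix) of its output
f_a = math.exp(y)
f_a_t = math.exp(y) * zeta_y            # depends on y, zeta_y
f_b = y * f_a
f_b_t = zeta_y * f_a + y * f_a_t        # depends on y, zeta_y, f_a, f_a_t
f_c = math.sin(f_b)
f_c_t = math.cos(f_b) * f_b_t           # depends on f_b, f_b_t
z = x * f_c
z_t = zeta_x * f_c + x * f_c_t          # depends on x, zeta_x, f_c, f_c_t

print(f"{z_t=}")
```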

Computing second derivatives using an adjoint of a tangent-linear

Since tlm_adjoint has recorded both forward and tangent-linear calculations, we can now compute second derivative information using an adjoint associated with a tangent-linear. Specifically we can compute a Hessian action on \(\zeta\),

\[H \zeta = \frac{d}{dm} \left( \frac{dz}{dm} \zeta \right)^T.\]

The inner directional derivative appearing here is computed using the tangent-linear method, and the outer derivative is computed by applying the adjoint method to the tangent-linear and forward calculations. In code we simply use compute_gradient to compute the derivative of the tangent-linear result:

[8]:
from tlm_adjoint import *

import numpy as np

reset_manager()


def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z


x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")

m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
configure_tlm((m, zeta))

start_manager()
z = forward(x, y)
stop_manager()

dz_dm_zeta = var_tlm(z, (m, zeta))

print(f"{float(dz_dm_zeta)=}")

d2z_dm_zeta_dx, d2z_dm_zeta_dy = compute_gradient(dz_dm_zeta, m)

print(f"{float(d2z_dm_zeta_dx)=}")
print(f"{float(d2z_dm_zeta_dy)=}")
float(dz_dm_zeta)=0.2005295584792861
float(d2z_dm_zeta_dx)=-1.7764708733484629
float(d2z_dm_zeta_dy)=-49.4290736694207
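As a check, the Hessian action can be assembled from second derivatives worked out by hand, with \(g \left( y \right) = y e^y\) as a shorthand introduced here. Since z is linear in x, \(d^2 z / dx^2 = 0\). A plain-Python sketch:

```python
import math

x, y = 1.0, 0.25 * math.pi
zeta = (2.0, 3.0)

g = y * math.exp(y)              # g(y) = y e^y
g_1 = (1.0 + y) * math.exp(y)    # g'(y)
g_2 = (2.0 + y) * math.exp(y)    # g''(y)

# Second derivatives of z = x sin(g(y))
z_xx = 0.0
z_xy = math.cos(g) * g_1
z_yy = x * (-math.sin(g) * g_1 ** 2 + math.cos(g) * g_2)

# Hessian action H zeta
H_zeta = (z_xx * zeta[0] + z_xy * zeta[1],
          z_xy * zeta[0] + z_yy * zeta[1])
print(H_zeta)
```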

Computing second derivatives using a tangent-linear of a tangent-linear

We can also compute second derivative information using a tangent-linear associated with a tangent-linear. For example, if we define \(e_1 = \left( 1, 0 \right)^T\), then we can find the first component of the previously computed Hessian action on \(\zeta\) via

\[e_1^T H \zeta = \left[ \frac{d}{dm} \left( \frac{dz}{dm} \zeta \right) \right] e_1.\]

That is, here we now want to compute a directional derivative of a directional derivative, and we compute this using a tangent-linear associated with the previous tangent-linear and forward calculations.

In tlm_adjoint this case is handled by supplying more arguments to configure_tlm:

[9]:
from tlm_adjoint import *

import numpy as np

reset_manager()


def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z


x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")

m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
e_1 = (Float(1.0, name="e_1_x"), Float(0.0, name="e_1_y"))
configure_tlm((m, zeta), (m, e_1))

start_manager()
z = forward(x, y)
stop_manager()

dz_dm_zeta = var_tlm(z, (m, zeta))
dz_dx = var_tlm(z, (m, e_1))
d2z_dm_zeta_dx = var_tlm(z, (m, zeta), (m, e_1))

print(f"{float(dz_dm_zeta)=}")
print(f"{float(dz_dx)=}")
print(f"{float(d2z_dm_zeta_dx)=}")
float(dz_dm_zeta)=0.2005295584792861
float(dz_dx)=0.9885002159138745
float(d2z_dm_zeta_dx)=-1.7764708733484629

The first control-direction pair supplied to configure_tlm indicates that we seek to compute directional derivatives of the forward with respect to the control defined by m with direction defined by zeta. The second control-direction pair indicates that we seek to compute directional derivatives of these directional derivatives, with respect to the control defined by m and with direction defined by e_1. As a side effect, we also compute the directional derivative of the forward with respect to the control defined by m with direction defined by e_1, which here is simply dz/dx.

We then access the tangent-linear variables using var_tlm, supplying two control-direction pairs to access the second order tangent-linear variable.
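A forward-mode analogue of a tangent-linear of a tangent-linear is hyper-dual arithmetic, where each number carries a value, two first order perturbations, and a mixed second order term. The sketch below is plain Python and is not tlm_adjoint's implementation; HyperDual, lift, hd_exp and hd_sin are hypothetical names. Seeding the first perturbation with zeta and the second with e_1 should reproduce the three tangent-linear values printed above.

```python
import math
from dataclasses import dataclass


@dataclass
class HyperDual:
    """a + b*eps1 + c*eps2 + d*eps1*eps2, with eps1**2 = eps2**2 = 0."""
    a: float
    b: float
    c: float
    d: float

    def __mul__(self, other):
        return HyperDual(self.a * other.a,
                         self.a * other.b + self.b * other.a,
                         self.a * other.c + self.c * other.a,
                         self.a * other.d + self.b * other.c
                         + self.c * other.b + self.d * other.a)


def lift(f, df, d2f, u):
    # Chain rule to second order for a scalar function f with derivatives df, d2f
    return HyperDual(f(u.a), df(u.a) * u.b, df(u.a) * u.c,
                     df(u.a) * u.d + d2f(u.a) * u.b * u.c)


def hd_exp(u):
    return lift(math.exp, math.exp, math.exp, u)


def hd_sin(u):
    return lift(math.sin, math.cos, lambda t: -math.sin(t), u)


# eps1 carries the direction zeta = (2, 3), eps2 carries e_1 = (1, 0)
x = HyperDual(1.0, 2.0, 1.0, 0.0)
y = HyperDual(0.25 * math.pi, 3.0, 0.0, 0.0)
z = x * hd_sin(y * hd_exp(y))

print(f"{z.b=}")  # (dz/dm) zeta
print(f"{z.c=}")  # (dz/dm) e_1, i.e. dz/dx
print(f"{z.d=}")  # e_1^T H zeta
```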

As before, tlm_adjoint has not just performed the tangent-linear calculations – if we display the information recorded by tlm_adjoint:

[10]:
manager_info()
Equation manager status:
Annotation state: AnnotationState.STOPPED
Tangent-linear state: TangentLinearState.STOPPED
Equations:
  Block 0
    Equation 0, FloatEquation solving for f_103 (id 103)
      Dependency 0, f_103 (id 103), linear
      Dependency 1, y (id 98), non-linear
    Equation 1, FloatEquation solving for f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105)
      Dependency 0, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, zeta_y (id 100), non-linear
    Equation 2, FloatEquation solving for f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107)
      Dependency 0, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, e_1_y (id 102), non-linear
    Equation 3, FloatEquation solving for f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110)
      Dependency 0, f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, zeta_y (id 100), non-linear
      Dependency 3, e_1_y (id 102), non-linear
      Dependency 4, zeta_y_tlm((x,y),(e_1_x,e_1_y)) (id 109), non-linear
    Equation 4, FloatEquation solving for f_112 (id 112)
      Dependency 0, f_112 (id 112), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, f_103 (id 103), non-linear
    Equation 5, FloatEquation solving for f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114)
      Dependency 0, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, zeta_y (id 100), non-linear
      Dependency 3, f_103 (id 103), non-linear
      Dependency 4, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), non-linear
    Equation 6, FloatEquation solving for f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116)
      Dependency 0, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, e_1_y (id 102), non-linear
      Dependency 3, f_103 (id 103), non-linear
      Dependency 4, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), non-linear
    Equation 7, FloatEquation solving for f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118)
      Dependency 0, f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118), linear
      Dependency 1, y (id 98), non-linear
      Dependency 2, zeta_y (id 100), non-linear
      Dependency 3, e_1_y (id 102), non-linear
      Dependency 4, f_103 (id 103), non-linear
      Dependency 5, f_103_tlm((x,y),(zeta_x,zeta_y)) (id 105), non-linear
      Dependency 6, f_103_tlm((x,y),(e_1_x,e_1_y)) (id 107), non-linear
      Dependency 7, zeta_y_tlm((x,y),(e_1_x,e_1_y)) (id 109), non-linear
      Dependency 8, f_103_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 110), non-linear
    Equation 8, FloatEquation solving for f_120 (id 120)
      Dependency 0, f_120 (id 120), linear
      Dependency 1, f_112 (id 112), non-linear
    Equation 9, FloatEquation solving for f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122)
      Dependency 0, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), linear
      Dependency 1, f_112 (id 112), non-linear
      Dependency 2, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), non-linear
    Equation 10, FloatEquation solving for f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124)
      Dependency 0, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), linear
      Dependency 1, f_112 (id 112), non-linear
      Dependency 2, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), non-linear
    Equation 11, FloatEquation solving for f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126)
      Dependency 0, f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126), linear
      Dependency 1, f_112 (id 112), non-linear
      Dependency 2, f_112_tlm((x,y),(zeta_x,zeta_y)) (id 114), non-linear
      Dependency 3, f_112_tlm((x,y),(e_1_x,e_1_y)) (id 116), non-linear
      Dependency 4, f_112_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 118), non-linear
    Equation 12, FloatEquation solving for f_128 (id 128)
      Dependency 0, f_128 (id 128), linear
      Dependency 1, x (id 97), non-linear
      Dependency 2, f_120 (id 120), non-linear
    Equation 13, FloatEquation solving for f_128_tlm((x,y),(zeta_x,zeta_y)) (id 131)
      Dependency 0, f_128_tlm((x,y),(zeta_x,zeta_y)) (id 131), linear
      Dependency 1, x (id 97), non-linear
      Dependency 2, zeta_x (id 99), non-linear
      Dependency 3, f_120 (id 120), non-linear
      Dependency 4, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), non-linear
    Equation 14, FloatEquation solving for f_128_tlm((x,y),(e_1_x,e_1_y)) (id 134)
      Dependency 0, f_128_tlm((x,y),(e_1_x,e_1_y)) (id 134), linear
      Dependency 1, x (id 97), non-linear
      Dependency 2, e_1_x (id 101), non-linear
      Dependency 3, f_120 (id 120), non-linear
      Dependency 4, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), non-linear
    Equation 15, FloatEquation solving for f_128_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 138)
      Dependency 0, f_128_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 138), linear
      Dependency 1, x (id 97), non-linear
      Dependency 2, zeta_x (id 99), non-linear
      Dependency 3, e_1_x (id 101), non-linear
      Dependency 4, f_120 (id 120), non-linear
      Dependency 5, f_120_tlm((x,y),(zeta_x,zeta_y)) (id 122), non-linear
      Dependency 6, f_120_tlm((x,y),(e_1_x,e_1_y)) (id 124), non-linear
      Dependency 7, f_120_tlm((x,y),(zeta_x,zeta_y))_tlm((x,y),(e_1_x,e_1_y)) (id 126), non-linear
      Dependency 8, zeta_x_tlm((x,y),(e_1_x,e_1_y)) (id 137), non-linear
Storage:
  Storing initial conditions: yes
  Storing equation non-linear dependencies: yes
  Initial conditions stored: 8
  Initial conditions referenced: 0
Checkpointing:
  Method: memory

we now find that there are sixteen FloatEquation records, constituting the forward and all tangent-linear calculations.
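The count of sixteen can be understood as follows: each of the four forward equations is recorded once, together with one tangent-linear equation for every non-empty subset of the two control-direction pairs. The sketch below regenerates the recorded names in this pattern. The naming scheme is copied from the listing above; the extension to more pairs (\(2^k\) records per forward equation for k pairs) is an extrapolation, not something checked against tlm_adjoint.

```python
from itertools import combinations

# Control-direction pair labels, in the order given to configure_tlm
pairs = ["((x,y),(zeta_x,zeta_y))", "((x,y),(e_1_x,e_1_y))"]
# Forward variable names taken from the listing above
forward_variables = ["f_103", "f_112", "f_120", "f_128"]

records = []
for name in forward_variables:
    # The forward variable itself (empty subset), then one tangent-linear
    # variable per non-empty subset of the pairs, in configure_tlm order
    for order in range(len(pairs) + 1):
        for subset in combinations(pairs, order):
            records.append(name + "".join(f"_tlm{pair}" for pair in subset))

print(len(records))
for record in records[:4]:
    print(record)
```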

Computing third derivatives using an adjoint of a tangent-linear of a tangent-linear

We can now differentiate the second derivative computed by the tangent-linear, by passing the second order tangent-linear variable to compute_gradient. This applies the adjoint method to the higher order tangent-linear calculations and the forward calculations, computing

\[\frac{d}{dm} \left( e_1^T H \zeta \right) = \frac{d}{dm} \left[ \left[ \frac{d}{dm} \left( \frac{dz}{dm} \zeta \right) \right] e_1 \right].\]
[11]:
from tlm_adjoint import *

import numpy as np

reset_manager()


def forward(x, y):
    z = x * np.sin(y * np.exp(y))
    return z


x = Float(1.0, name="x")
y = Float(0.25 * np.pi, name="y")

m = (x, y)
zeta = (Float(2.0, name="zeta_x"), Float(3.0, name="zeta_y"))
e_1 = (Float(1.0, name="e_1_x"), Float(0.0, name="e_1_y"))
configure_tlm((m, zeta), (m, e_1))

start_manager()
z = forward(x, y)
stop_manager()

dz_dm_zeta = var_tlm(z, (m, zeta))
dz_dx = var_tlm(z, (m, e_1))
d2z_dm_zeta_dx = var_tlm(z, (m, zeta), (m, e_1))

print(f"{float(dz_dm_zeta)=}")
print(f"{float(dz_dx)=}")
print(f"{float(d2z_dm_zeta_dx)=}")

d3z_dm_zeta_dx_dx, d3z_dm_zeta_dx_dy = compute_gradient(d2z_dm_zeta_dx, m)

print(f"{float(d3z_dm_zeta_dx_dx)=}")
print(f"{float(d3z_dm_zeta_dx_dy)=}")
float(dz_dm_zeta)=0.2005295584792861
float(dz_dx)=0.9885002159138745
float(d2z_dm_zeta_dx)=-1.7764708733484629
float(d3z_dm_zeta_dx_dx)=0.0
float(d3z_dm_zeta_dx_dy)=-48.244759753855064
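These values can be checked by hand. With \(g \left( y \right) = y e^y\) as a shorthand introduced here, \(e_1^T H \zeta = 2 \, d^2 z / dx^2 + 3 \, d^2 z / dx \, dy = 3 \cos g \left( y \right) g' \left( y \right)\). This does not depend on x, which explains the zero x-component. A plain-Python sketch of the check:

```python
import math

x, y = 1.0, 0.25 * math.pi
zeta = (2.0, 3.0)

g = y * math.exp(y)              # g(y) = y e^y
g_1 = (1.0 + y) * math.exp(y)    # g'(y)
g_2 = (2.0 + y) * math.exp(y)    # g''(y)

# e_1^T H zeta = zeta_x z_xx + zeta_y z_xy, with z_xx = 0 and
# z_xy = cos(g) g', which does not depend on x
e1_H_zeta = zeta[1] * math.cos(g) * g_1

# Derivatives of e_1^T H zeta with respect to x and y
d_dx = 0.0
d_dy = zeta[1] * (-math.sin(g) * g_1 ** 2 + math.cos(g) * g_2)
print(e1_H_zeta, d_dx, d_dy)
```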

Higher order

The approach now generalizes.

  • Supplying further control-direction pairs to configure_tlm requests directional derivatives of directional derivatives, defining tangent-linear calculations of increasingly high order.

  • Supplying these control-direction pairs to var_tlm accesses the corresponding higher order tangent-linear variables.

  • These higher order tangent-linear variables can be passed to compute_gradient to compute derivatives of the higher order derivatives using the adjoint method.
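The same hierarchy can be imitated in plain Python with nested dual numbers, a standard forward-mode technique in which each nesting level carries one direction. The sketch below is an analogy, not tlm_adjoint's implementation; Dual, seed and mixed_derivative are hypothetical names, and \(e_2 = \left( 0, 1 \right)^T\) is a direction introduced here. Nesting three levels with directions \(\zeta\), \(e_1\) and \(e_2\) gives a third order directional derivative that should match d3z_dm_zeta_dx_dy from the adjoint calculation above.

```python
import math


class Dual:
    """v + t*eps with eps**2 = 0. v and t may themselves be Dual numbers,
    giving nested (higher order) directional derivatives. This minimal sketch
    assumes all operands have the same nesting depth."""

    def __init__(self, v, t):
        self.v = v
        self.t = t

    def __add__(self, other):
        return Dual(self.v + other.v, self.t + other.t)

    def __mul__(self, other):
        return Dual(self.v * other.v, self.v * other.t + self.t * other.v)

    def __neg__(self):
        return Dual(-self.v, -self.t)


def d_exp(u):
    if not isinstance(u, Dual):
        return math.exp(u)
    e = d_exp(u.v)
    return Dual(e, e * u.t)


def d_sin(u):
    if not isinstance(u, Dual):
        return math.sin(u)
    return Dual(d_sin(u.v), d_cos(u.v) * u.t)


def d_cos(u):
    if not isinstance(u, Dual):
        return math.cos(u)
    return Dual(d_cos(u.v), -d_sin(u.v) * u.t)


def seed(value, directions):
    """value + sum_i directions[i] * eps_i, nested with eps_1 innermost."""
    if not directions:
        return value
    *inner, outer = directions
    return Dual(seed(value, inner), seed(outer, [0.0] * len(inner)))


def mixed_derivative(u, order):
    """Coefficient of eps_1 * ... * eps_order."""
    for _ in range(order):
        u = u.t
    return u


# Per-variable components of the directions zeta = (2, 3), e_1 = (1, 0),
# e_2 = (0, 1), one direction per nesting level
x = seed(1.0, [2.0, 1.0, 0.0])
y = seed(0.25 * math.pi, [3.0, 0.0, 1.0])
z = x * d_sin(y * d_exp(y))

print(mixed_derivative(z.v, 2))  # second order: e_1^T H zeta
print(mixed_derivative(z, 3))    # third order, directions zeta, e_1, e_2
```

Each extra nesting level doubles the number of floats carried (\(2^k\) at depth k), which is one reason high order forward-mode calculations become expensive.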