tlm_adjoint.jax
Module Contents
- tlm_adjoint.jax.set_default_jax_dtype(dtype)
Set the default data type used by Vector objects.
- Parameters:
dtype – The default data type. If None then the default JAX floating point scalar data type is used.
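The "default JAX floating point scalar data type" depends on JAX's own configuration. JAX uses 32-bit floats unless its 64-bit mode is enabled. This matters when the intended data type is numpy.double, the typical choice for a VectorSpace.

The sketch below uses only public JAX calls to inspect and change that default. The helper default_float_dtype is hypothetical. The sketch makes no claim about when set_default_jax_dtype reads this setting.

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    # Hypothetical helper: the dtype JAX gives a floating point array created
    # without an explicit dtype. This is float32 unless 64-bit mode is enabled.
    return jnp.zeros(()).dtype


print(default_float_dtype())  # float32 in a default JAX configuration

# Enable 64-bit types, e.g. to match numpy.double. This is best done at
# program start-up, before any arrays are created.
jax.config.update("jax_enable_x64", True)
print(default_float_dtype())  # float64
```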
- class tlm_adjoint.jax.VectorSpace(n, *, dtype=None, comm=None)
A vector space.
- Parameters:
n – The number of local degrees of freedom.
dtype – The data type associated with the space. Typically numpy.double or numpy.cdouble.
comm – The communicator associated with the space.
- property dtype
The data type associated with the space.
- property comm
The communicator associated with the space.
- property local_size
The number of local degrees of freedom.
- property global_size
The global number of degrees of freedom.
- property ownership_range
A tuple (n0, n1), indicating that slice(n0, n1) is the range of nodes in the global vector owned by this process.
- rdtype()
The real data type associated with the space.
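The relationships between local_size, global_size, ownership_range and rdtype can be illustrated without MPI.

The sketch below makes two assumptions, both labelled hypothetical in the code:
- Each process owns one contiguous block of the global vector, and blocks are laid out in rank order.
- The real data type is the type of the real component of dtype.

Neither helper is part of tlm_adjoint.jax.

```python
import numpy as np


def block_ownership_ranges(local_sizes):
    # Hypothetical helper: given the local size on each process, in rank
    # order, return the (n0, n1) tuple for each process, so that
    # slice(n0, n1) selects that process's entries of the global vector.
    ranges = []
    n0 = 0
    for n in local_sizes:
        ranges.append((n0, n0 + n))
        n0 += n
    return ranges


def real_dtype(dtype):
    # Hypothetical helper: the real type matching a floating point or complex
    # dtype, e.g. numpy.cdouble -> numpy.double. np.finfo applied to a complex
    # type describes its real component.
    return np.finfo(dtype).dtype


local_sizes = [3, 2, 4]  # e.g. three processes
ranges = block_ownership_ranges(local_sizes)
print(ranges)  # [(0, 3), (3, 5), (5, 9)]

global_size = sum(local_sizes)
global_vector = np.arange(global_size, dtype=np.double)
for rank, (n0, n1) in enumerate(ranges):
    local = global_vector[slice(n0, n1)]
    assert local.size == local_sizes[rank]  # local_size == n1 - n0

print(real_dtype(np.cdouble), real_dtype(np.double))  # float64 float64
```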
- class tlm_adjoint.jax.Vector(V, *, name=None, space_type='primal', static=False, cache=None, dtype=None, comm=None)
Vector, with degrees of freedom stored as a JAX array.
- Parameters:
V – A VectorSpace, an int defining the number of local degrees of freedom, or an ndim 1 array defining the local degrees of freedom.
name – A str name.
space_type – The space type for the Vector. ‘primal’, ‘dual’, ‘conjugate’, or ‘conjugate_dual’.
static – Defines whether the Vector is static, meaning that it is stored by reference in checkpointing/replay, and an associated tangent-linear variable is zero.
cache – Defines whether results involving the Vector may be cached. Defaults to the value of static.
dtype – The data type. Ignored if V is a VectorSpace.
comm – A communicator. Ignored if V is a VectorSpace.
- property value
For a Vector with one element, the value of the element. The value may also be accessed by casting using float or complex.
- Returns:
The value.
- property space
The VectorSpace for the Vector.
- property vector
A JAX array storing the local degrees of freedom.
- new(y=None, *, name=None, static=False, cache=None)
Return a new Vector, with the same VectorSpace and space type as this Vector. Remaining arguments are as for the Vector constructor.
- assign(y)
Vector assignment.
- Parameters:
y – A numbers.Complex, Vector, or ndim 1 array defining the value.
- Returns:
The Vector.
- addto(y)
Vector in-place addition.
- Parameters:
y – A numbers.Complex, Vector, or ndim 1 array defining the value to add.
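JAX arrays are immutable. An "in-place" addto on an object backed by a JAX array can therefore only update the object, by storing a new array; it cannot overwrite the old buffer.

The sketch below shows this with a hypothetical ToyVector class. It is not tlm_adjoint.jax.Vector, and how Vector itself stores its data is not stated here. The way ToyVector accepts the three documented forms of y is also an assumption: a scalar is broadcast to every element, a ToyVector contributes its values, and an ndim 1 array is converted directly.

```python
import numbers

import jax.numpy as jnp


class ToyVector:
    """Hypothetical stand-in, not tlm_adjoint.jax.Vector: degrees of freedom
    are stored in an (immutable) JAX array."""

    def __init__(self, n, dtype=jnp.float32):
        self._vector = jnp.zeros(n, dtype=dtype)

    @property
    def vector(self):
        return self._vector

    def _as_array(self, y):
        # Accept a scalar, another ToyVector, or an ndim 1 array.
        if isinstance(y, numbers.Complex):
            return jnp.full(self._vector.shape, y, dtype=self._vector.dtype)
        if isinstance(y, ToyVector):
            y = y.vector
        y = jnp.asarray(y, dtype=self._vector.dtype)
        if y.shape != self._vector.shape:
            raise ValueError("Invalid shape")
        return y

    def assign(self, y):
        # The existing array cannot be modified, so a new one is stored.
        self._vector = self._as_array(y)
        return self

    def addto(self, y):
        # "In-place" at the level of the object: the stored array is
        # replaced by a new array holding the sum.
        self._vector = self._vector + self._as_array(y)


x = ToyVector(3)
old = x.vector
x.assign([1.0, 2.0, 3.0])
x.addto(1.0)
print(x.vector)  # [2. 3. 4.]
print(old)       # [0. 0. 0.], the original array is unchanged
```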
- class tlm_adjoint.jax.VectorEquation(X, Y, fn, *, with_tlm=True, _forward_eq=None)
JAX interface. fn should be a callable

    def fn(y0, y1, ...):
        ...
        return x0, x1, ...

where y0, y1, ... are ndim 1 JAX arrays, and x0, x1, ... are scalars or ndim 1 JAX arrays.
- Parameters:
X – A Vector or a Sequence of Vector objects defining the outputs, whose values are set by the values returned by fn.
Y – A Vector or a Sequence of Vector objects defining the inputs, whose values are passed to fn.
fn – A callable.
with_tlm – Whether to annotate an equation solving for the forward and all tangent-linears (with_tlm=True), or solving only for the forward (with_tlm=False).
- solve(*, annotate=None, tlm=None)
Compute the forward solution.
- Parameters:
annotate – Whether the EquationManager should record the solution of equations.
tlm – Whether tangent-linear equations should be solved.
- forward_solve(X, deps=None)
Compute the forward solution.
Implementations can assume that the currently active EquationManager is paused.
- Parameters:
X – A variable or a Sequence of variables storing the solution. May define an initial guess, and should be set by this method.
deps – A tuple of variables defining values for dependencies. Only the elements corresponding to X may be modified. If not supplied, self.dependencies() should be used.
- adjoint_jacobian_solve(adj_X, nl_deps, B)
Compute an adjoint solution.
- Parameters:
adj_X – Either None, or a variable or Sequence of variables defining the initial guess for an iterative solve. May be modified or returned.
nl_deps – A Sequence of variables defining values for non-linear dependencies. Should not be modified.
B – The right-hand-side. A variable or Sequence of variables storing the value of the right-hand-side. May be modified or returned.
- Returns:
A variable or Sequence of variables storing the value of the adjoint solution. May return None to indicate a value of zero.
- subtract_adjoint_derivative_actions(adj_X, nl_deps, dep_Bs)
Subtract terms from other adjoint right-hand-sides.
Can be overridden for an optimized implementation, but otherwise uses Equation.adjoint_derivative_action().
- Parameters:
adj_X – The adjoint solution. A variable or a Sequence of variables. Should not be modified.
nl_deps – A Sequence of variables defining values for non-linear dependencies. Should not be modified.
dep_Bs – A Mapping whose items are (dep_index, dep_B). Each dep_B is an AdjointRHS which should be updated by subtracting adjoint derivative information computed by differentiating with respect to self.dependencies()[dep_index].
- tangent_linear(tlm_map)
Derive an Equation corresponding to a tangent-linear operation.
- Parameters:
tlm_map – A TangentLinearMap storing values for tangent-linear variables.
- Returns:
An Equation corresponding to the tangent-linear operation.
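This section does not describe how VectorEquation derives its tangent-linear and adjoint operations. The derivative information such operations need can, however, be obtained from JAX directly:
- jax.jvp computes a Jacobian action J v, as used by a tangent-linear.
- jax.vjp computes an adjoint Jacobian action Jᵀ w, as used by an adjoint.

The sketch below applies both to a hypothetical real-valued fn and checks that they agree via the dot-product test ⟨w, J v⟩ = ⟨Jᵀ w, v⟩. It is an illustration of the mathematics only, not a description of how tlm_adjoint calls JAX internally.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the check


def fn(y0, y1):
    # Hypothetical fn following the VectorEquation convention: ndim 1
    # inputs, and here a single ndim 1 output.
    return jnp.sin(y0) * y1


key0, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(0), 5)
y0 = jax.random.normal(key0, (4,))
y1 = jax.random.normal(key1, (4,))

# Tangent-linear direction (v0, v1) and adjoint direction w.
v0 = jax.random.normal(key2, (4,))
v1 = jax.random.normal(key3, (4,))
w = jax.random.normal(key4, (4,))

# Jacobian action J v, via forward mode.
_, Jv = jax.jvp(fn, (y0, y1), (v0, v1))

# Adjoint action J^T w, via reverse mode: one cotangent per input.
_, fn_vjp = jax.vjp(fn, y0, y1)
JTw0, JTw1 = fn_vjp(w)

lhs = jnp.dot(w, Jv)
rhs = jnp.dot(JTw0, v0) + jnp.dot(JTw1, v1)
print(lhs, rhs)  # equal up to rounding error
```

For complex-valued data, the inner products would need complex conjugation, which this real-valued sketch omits.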
- tlm_adjoint.jax.call_jax(X, Y, fn)
JAX interface. fn should be a callable

    def fn(y0, y1, ...):
        ...
        return x0, x1, ...

where y0, y1, ... are ndim 1 JAX arrays, and x0, x1, ... are scalars or ndim 1 JAX arrays.
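The calling convention above can be illustrated with a hypothetical fn that has two ndim 1 inputs and two outputs, one a scalar and one an ndim 1 array. Writing fn with jax.numpy rather than numpy keeps it traceable by JAX, which is needed if JAX is to differentiate it.

In a call_jax(X, Y, fn) call, Y would then presumably hold two Vector objects, and X two Vector objects, the first having one element to receive the scalar. That pairing is an assumption based on the VectorEquation parameter descriptions. The sketch only exercises fn directly.

```python
import jax.numpy as jnp


def fn(y0, y1):
    # Hypothetical example of the fn convention: y0 and y1 are ndim 1 JAX
    # arrays; the outputs are a scalar and an ndim 1 array.
    x0 = jnp.dot(y0, y1)  # scalar output
    x1 = y0 ** 2 + y1     # ndim 1 output, here the same length as the inputs
    return x0, x1


y0 = jnp.array([1.0, 2.0, 3.0])
y1 = jnp.array([4.0, 5.0, 6.0])
x0, x1 = fn(y0, y1)
print(x0)  # 32.0
print(x1)  # values 5, 9, 15
```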
- tlm_adjoint.jax.new_jax(y, space=None, *, name=None)
Construct a new zero-valued Vector.
- Parameters:
y – A variable.
space – The VectorSpace for the return value.
name – A str name.
- Returns:
The Vector.
- tlm_adjoint.jax.to_jax(y, space=None, *, name=None)
Convert a variable to a Vector.
- Parameters:
y – A variable.
space – The VectorSpace for the return value.
name – A str name.
- Returns:
The Vector.
- tlm_adjoint.jax.new_jax_float(space=None, *, name=None, dtype=None, comm=None)
Create a new Vector with one element.
- Parameters:
space – The VectorSpace.
name – A str name.
dtype – The data type. Ignored if space is supplied.
comm – A communicator. Ignored if space is supplied.
- Returns:
A Vector with one element.