tlm_adjoint.jax

Module Contents

tlm_adjoint.jax.set_default_jax_dtype(dtype)

Set the default data type used by Vector objects.

Parameters:

dtype – The default data type.
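The pattern behind a module-level default is sketched below, assuming that constructors receiving dtype=None fall back to the current default. The names _default_dtype, set_default_dtype and resolve_dtype are hypothetical, and numpy.double as the initial default is an assumption.

```python
import numpy as np

# Hypothetical module-level default; numpy.double as the initial value is an assumption.
_default_dtype = np.double


def set_default_dtype(dtype):
    """Hypothetical analogue of set_default_jax_dtype."""
    global _default_dtype
    _default_dtype = dtype


def resolve_dtype(dtype=None):
    """Hypothetical helper: an explicit dtype wins, otherwise use the default."""
    return _default_dtype if dtype is None else dtype


set_default_dtype(np.cdouble)
default = resolve_dtype()          # now numpy.cdouble
explicit = resolve_dtype(np.single)  # an explicit dtype is used unchanged
```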

class tlm_adjoint.jax.VectorSpace(n, *, dtype=None, comm=None)

A vector space.

Parameters:
  • n – The number of local degrees of freedom.

  • dtype – The data type associated with the space. Typically numpy.double or numpy.cdouble.

  • comm – The communicator associated with the space.

property dtype

The data type associated with the space.

property comm

The communicator associated with the space.

property local_size

The number of local degrees of freedom.

property global_size

The global number of degrees of freedom.

property ownership_range

A tuple (n0, n1), indicating that slice(n0, n1) is the range of nodes in the global vector owned by this process.
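The ownership range can be understood as an exclusive prefix sum of the local sizes over process ranks. The sketch below assumes contiguous ownership in rank order and simulates the per-rank local sizes with a plain list instead of a communicator; ownership_ranges is a hypothetical helper. With mpi4py, n1 could be obtained from comm.scan(local_size), with n0 = n1 - local_size.

```python
from itertools import accumulate


def ownership_ranges(local_sizes):
    """Return (n0, n1) for each rank, assuming rank r owns a contiguous block
    of global indices placed after those owned by ranks 0, ..., r - 1."""
    ends = list(accumulate(local_sizes))  # inclusive prefix sums give n1
    return [(n1 - n, n1) for n, n1 in zip(local_sizes, ends)]


local_sizes = [3, 0, 4, 2]        # hypothetical local sizes on four ranks
ranges = ownership_ranges(local_sizes)
global_size = sum(local_sizes)    # the global number of degrees of freedom
print(ranges)       # [(0, 3), (3, 3), (3, 7), (7, 9)]
print(global_size)  # 9
```

A rank with no local degrees of freedom gets an empty range (n0 == n1).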

rdtype()

The real data type associated with the space.
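For a complex space, the real data type is the type of the real and imaginary parts. One way to compute it is sketched here with NumPy; real_dtype is a hypothetical helper, not part of tlm_adjoint. It relies on numpy.finfo, whose dtype attribute is the associated real type for complex input.

```python
import numpy as np


def real_dtype(dtype):
    """Real floating-point type associated with a (possibly complex) dtype."""
    return np.finfo(dtype).dtype


print(real_dtype(np.double))   # float64
print(real_dtype(np.cdouble))  # float64
print(real_dtype(np.csingle))  # float32
```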

class tlm_adjoint.jax.Vector(V, *, name=None, space_type='primal', static=False, cache=None, dtype=None, comm=None)

Vector, with degrees of freedom stored as a JAX array.

Parameters:
  • V – A VectorSpace, an int defining the number of local degrees of freedom, or an ndim 1 array defining the local degrees of freedom.

  • name – A str name for the Vector.

  • space_type – The space type for the Vector: ‘primal’, ‘dual’, ‘conjugate’, or ‘conjugate_dual’.

  • static – Defines whether the Vector is static, meaning that it is stored by reference in checkpointing/replay, and an associated tangent-linear variable is zero.

  • cache – Defines whether results involving the Vector may be cached. Defaults to the value of static.

  • dtype – The data type. Ignored if V is a VectorSpace.

  • comm – A communicator. Ignored if V is a VectorSpace.
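The three accepted forms of V can be normalized to an array of local degrees of freedom. The NumPy sketch below is a hypothetical illustration of that dispatch (local_dofs is not a tlm_adjoint function). It detects a space by duck typing on local_size and dtype, assumes zero initial values for the space and int forms, and mirrors the rule that dtype is ignored when V is a space.

```python
from types import SimpleNamespace

import numpy as np


def local_dofs(V, dtype=np.double):
    """Hypothetical normalization of V to a 1-D array of local values."""
    if hasattr(V, "local_size"):
        # A space: take the size and dtype from the space, ignoring dtype.
        return np.zeros(V.local_size, dtype=V.dtype)
    if isinstance(V, (int, np.integer)):
        # An int: the number of local degrees of freedom, zero-initialized.
        return np.zeros(V, dtype=dtype)
    # Otherwise an array of local degrees of freedom, which must be ndim 1.
    values = np.array(V, dtype=dtype)
    if values.ndim != 1:
        raise ValueError("Expected an ndim 1 array")
    return values


space = SimpleNamespace(local_size=3, dtype=np.cdouble)  # stand-in for a VectorSpace
print(local_dofs(space, dtype=np.double).dtype)  # complex128
print(local_dofs(2))                             # [0. 0.]
print(local_dofs([1.0, 2.0, 3.0]))               # [1. 2. 3.]
```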

property name

The str name of the Vector.

property value

For a Vector with one element, the value of the element.

The value may also be accessed by casting with float or complex.

Returns:

The value.
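The value property and the float/complex casts can be pictured with a minimal stand-in class. This is a hypothetical NumPy sketch (OneElementVector is not a tlm_adjoint class), and it assumes that accessing value for a Vector with more than one element is an error.

```python
import numpy as np


class OneElementVector:
    """Hypothetical stand-in illustrating value and float/complex casting."""

    def __init__(self, values):
        self._values = np.asarray(values)

    @property
    def value(self):
        # Only meaningful when exactly one element is stored.
        if self._values.shape != (1,):
            raise ValueError("value requires a Vector with one element")
        return self._values[0]

    def __float__(self):
        return float(self.value)

    def __complex__(self):
        return complex(self.value)


x = OneElementVector([2.5])
print(x.value, float(x), complex(x))  # 2.5 2.5 (2.5+0j)
```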

property space

The VectorSpace for the Vector.

property space_type

The space type for the Vector.

property vector

A JAX array storing the local degrees of freedom.

new(y=None, *, name=None, static=False, cache=None)

Return a new Vector, with the same VectorSpace and space type as this Vector.

Parameters:

y – Defines a value for the new Vector.

Returns:

The new Vector.

Remaining arguments are as for the Vector constructor.

assign(y)

Vector assignment.

Parameters:

y – A numbers.Complex, Vector, or ndim 1 array defining the value.

Returns:

The Vector.

addto(y)

Vector in-place addition.

Parameters:

y – A numbers.Complex, Vector, or ndim 1 array defining the value to add.
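Because JAX arrays are immutable, in-place operations on a JAX-backed vector amount to replacing the stored array. The sketch below illustrates this for the three accepted argument types of assign and addto. It uses NumPy so that it runs without JAX, ArrayVector is a hypothetical stand-in, and it assumes that a scalar argument applies to every element.

```python
import numbers

import numpy as np


class ArrayVector:
    """Hypothetical stand-in: values held in an array that is replaced, not mutated."""

    def __init__(self, n, dtype=np.double):
        self.vector = np.zeros(n, dtype=dtype)

    def _as_array(self, y):
        # Accept a scalar (applied to every element), another vector, or an ndim 1 array.
        if isinstance(y, numbers.Complex):
            return np.full_like(self.vector, y)
        if isinstance(y, ArrayVector):
            return y.vector
        y = np.asarray(y, dtype=self.vector.dtype)
        if y.shape != self.vector.shape:
            raise ValueError("Invalid shape")
        return y

    def assign(self, y):
        # Rebind the stored array, as an immutable JAX array would require.
        self.vector = self._as_array(y).copy()
        return self

    def addto(self, y):
        # Rebind to a new array holding the sum.
        self.vector = self.vector + self._as_array(y)


x = ArrayVector(3)
x.assign(1.0)
x.addto([0.0, 1.0, 2.0])
print(x.vector)  # [1. 2. 3.]
```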

class tlm_adjoint.jax.VectorEquation(X, Y, fn, *, with_tlm=True, _forward_eq=None)

JAX interface. fn should be a callable of the form

def fn(y0, y1, ...):
    ...
    return x0, x1, ...

where y0, y1, ... are ndim 1 JAX arrays, and x0, x1, ... are scalars or ndim 1 JAX arrays.
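As an illustration, a callable of this form with two inputs and two outputs might return an elementwise expression and a scalar reduction. The particular function is hypothetical. It uses only arithmetic operators and the sum method, which both JAX and NumPy arrays provide, so the check below runs with NumPy arrays; in use, the inputs would be JAX arrays.

```python
import numpy as np


def fn(y0, y1):
    # Inputs: ndim 1 arrays of equal length.
    x0 = y0 * y1 + 2.0 * y0  # ndim 1 output
    x1 = (y0 ** 2).sum()     # scalar output
    return x0, x1


x0, x1 = fn(np.array([1.0, 2.0]), np.array([3.0, 4.0]))
```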

Parameters:
  • X – A Vector or a Sequence of Vector objects defining outputs, whose value is set by the return value from fn.

  • Y – A Vector or a Sequence of Vector objects defining the inputs, whose values are passed to fn.

  • fn – A callable.

  • with_tlm – Whether to annotate an equation solving for the forward and all tangent-linears (with_tlm=True), or solving only for the forward (with_tlm=False).

solve(*, annotate=None, tlm=None)

Compute the forward solution.

Parameters:
  • annotate – Whether the EquationManager should record the solution of equations.

  • tlm – Whether tangent-linear equations should be solved.

forward_solve(X, deps=None)

Compute the forward solution.

Can assume that the currently active EquationManager is paused.

Parameters:
  • X – A variable if the forward solution has a single component, otherwise a Sequence of variables. May define an initial guess, and should be set by this method. Subclasses may replace this argument with x if the forward solution has a single component.

  • deps – A tuple of variables, defining values for dependencies. Only the elements corresponding to X may be modified. self.dependencies() should be used if not supplied.
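Conceptually, the forward solve for this equation evaluates fn on the input values and stores each return value in the corresponding output. The NumPy sketch below is an assumption about that step, not the library's implementation: forward_values is hypothetical, outputs are plain arrays rather than Vector objects, a lone (non-tuple) return value is accepted for a single output, and a scalar return value is assumed to fill a one-element output.

```python
import numpy as np


def forward_values(fn, X_sizes, Y_values):
    """Hypothetical sketch: evaluate fn and check its results against the output sizes."""
    results = fn(*Y_values)
    if not isinstance(results, tuple):
        results = (results,)              # allow a single return value
    if len(results) != len(X_sizes):
        raise ValueError("Unexpected number of outputs")
    X_values = []
    for x, n in zip(results, X_sizes):
        x = np.atleast_1d(np.asarray(x))  # a scalar becomes a one-element array
        if x.shape != (n,):
            raise ValueError("Invalid output shape")
        X_values.append(x)
    return X_values


def fn(y0, y1):
    return y0 * y1, (y0 * y1).sum()


X_values = forward_values(fn, [2, 1], [np.array([1.0, 2.0]), np.array([3.0, 4.0])])
```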

adjoint_jacobian_solve(adj_X, nl_deps, B)

Compute an adjoint solution.

Parameters:
  • adj_X – Either None, or a variable (if the adjoint solution has a single component) or Sequence of variables (otherwise) defining the initial guess for an iterative solve. May be modified or returned. Subclasses may replace this argument with adj_x if the adjoint solution has a single component.

  • nl_deps – A Sequence of variables defining values for non-linear dependencies. Should not be modified.

  • B – The right-hand-side. A variable (if the adjoint solution has a single component) or Sequence of variables (otherwise) storing the value of the right-hand-side. May be modified or returned. Subclasses may replace this argument with b if the adjoint solution has a single component.

Returns:

A variable or Sequence of variables storing the value of the adjoint solution. May return None to indicate a value of zero.
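For an equation of this type, the adjoint Jacobian solve becomes trivial if, as assumed here, the equation is written in residual form with the outputs appearing linearly. Writing x for the outputs, y for the inputs, f for fn, λ for the adjoint solution and b for the right-hand-side (notation introduced for this sketch):

```latex
F(x, y) \equiv x - f(y) = 0,
\qquad
\frac{\partial F}{\partial x} = I,
\qquad
\left( \frac{\partial F}{\partial x} \right)^{*} \lambda = b
\quad \Longrightarrow \quad
\lambda = b .
```

Under this assumption the adjoint solution would simply equal the right-hand-side B, so no iterative solve, and hence no initial guess, would be needed.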

subtract_adjoint_derivative_actions(adj_X, nl_deps, dep_Bs)

Subtract terms from other adjoint right-hand-sides.

Can be overridden for an optimized implementation, but otherwise uses Equation.adjoint_derivative_action().

Parameters:
  • adj_X – The adjoint solution. A variable if the adjoint solution has a single component, otherwise a Sequence of variables. Should not be modified. Subclasses may replace this argument with adj_x if the adjoint solution has a single component.

  • nl_deps – A Sequence of variables defining values for non-linear dependencies. Should not be modified.

  • dep_Bs – A Mapping whose items are (dep_index, dep_B). Each dep_B is an AdjointRHS which should be updated by subtracting adjoint derivative information computed by differentiating with respect to self.dependencies()[dep_index].
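Assuming the residual form F(x, y) = x - f(y) = 0, the term subtracted for dependency y is (∂F/∂y)^* λ = -(∂f/∂y)^* λ, so the subtraction adds the vector-Jacobian product (∂f/∂y)^* λ to that right-hand-side. The NumPy sketch below hand-codes this product for an illustrative real-valued f and checks it against the Jacobian-vector product with the dot-product test ⟨λ, J dy⟩ = ⟨J^T λ, dy⟩. In JAX, these two products are what jax.vjp and jax.jvp compute; f, jvp and vjp here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def f(y):
    # Illustrative fn: square elementwise, then shift cyclically by one.
    return np.roll(y ** 2, 1)


def jvp(y, dy):
    # Tangent-linear action J dy, with J = df/dy.
    return np.roll(2.0 * y * dy, 1)


def vjp(y, lam):
    # Adjoint action J^T lam: undo the shift, then scale.
    return 2.0 * y * np.roll(lam, -1)


y, dy, lam = rng.standard_normal((3, 5))
lhs = lam @ jvp(y, dy)
rhs = vjp(y, lam) @ dy
print(abs(lhs - rhs))  # agrees to rounding error
```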

tangent_linear(tlm_map)

Derive an Equation corresponding to a tangent-linear operation.

Parameters:

tlm_map – A TangentLinearMap storing values for tangent-linear variables.

Returns:

An Equation, corresponding to the tangent-linear operation.
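For x = f(y0, y1, ...), the tangent-linear operation propagates input perturbations through the derivative of f: dx = Σ_i (∂f/∂y_i) dy_i. The NumPy sketch below hand-codes this for an illustrative two-input f and checks it against a central finite difference; jax.jvp would compute the same product automatically. The functions f and f_tlm are hypothetical.

```python
import numpy as np


def f(y0, y1):
    # Illustrative fn with two inputs and one output.
    return y0 * y1 + np.sin(y0)


def f_tlm(y0, y1, dy0, dy1):
    # Tangent-linear: dx = (df/dy0) dy0 + (df/dy1) dy1.
    return (y1 + np.cos(y0)) * dy0 + y0 * dy1


rng = np.random.default_rng(1)
y0, y1, dy0, dy1 = rng.standard_normal((4, 6))
eps = 1.0e-6
fd = (f(y0 + eps * dy0, y1 + eps * dy1)
      - f(y0 - eps * dy0, y1 - eps * dy1)) / (2.0 * eps)
error = np.max(np.abs(fd - f_tlm(y0, y1, dy0, dy1)))
print(error)  # small: O(eps**2) truncation plus rounding error
```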

tlm_adjoint.jax.call_jax(X, Y, fn)

JAX interface. fn should be a callable of the form

def fn(y0, y1, ...):
    ...
    return x0, x1, ...

where y0, y1, ... are ndim 1 JAX arrays, and x0, x1, ... are scalars or ndim 1 JAX arrays.

Parameters:
  • X – A Vector or a Sequence of Vector objects defining outputs, whose value is set by the return value from fn.

  • Y – A Vector or a Sequence of Vector objects defining the inputs, whose values are passed to fn.

  • fn – A callable.

tlm_adjoint.jax.new_jax(y, space=None, *, name=None)

Construct a new zero-valued Vector.

Parameters:
  • y – A variable.

  • space – The VectorSpace for the return value.

  • name – A str name.

Returns:

The Vector.

tlm_adjoint.jax.to_jax(y, space=None, *, name=None)

Convert a variable to a Vector.

Parameters:
  • y – A variable.

  • space – The VectorSpace for the return value.

  • name – A str name.

Returns:

The Vector.

tlm_adjoint.jax.new_jax_float(space=None, *, name=None, dtype=None, comm=None)

Create a new Vector with one element.

Parameters:
  • space – The VectorSpace.

  • name – A str name.

  • dtype – The data type. Ignored if space is supplied.

  • comm – A communicator. Ignored if space is supplied.

Returns:

A Vector with one element.