:orphan:

:py:mod:`tlm_adjoint.jax`
=========================

.. py:module:: tlm_adjoint.jax


Module Contents
---------------

.. py:function:: set_default_jax_dtype(dtype)

   Set the default data type used by :class:`.Vector` objects.

   :Parameters:

       **dtype** : type
           The default data type. If `None` then the default JAX floating
           point scalar data type is used.

   .. !! processed by numpydoc !!

.. py:class:: VectorSpace(n, *, dtype=None, comm=None)

   A vector space.

   :arg dtype: The data type associated with the space. Typically
       :class:`numpy.double` or :class:`numpy.cdouble`.
   :arg comm: The communicator associated with the space.

   .. !! processed by numpydoc !!

   .. py:property:: dtype

      The data type associated with the space.

      .. !! processed by numpydoc !!

   .. py:property:: comm

      The communicator associated with the space.

      .. !! processed by numpydoc !!

   .. py:property:: local_size

      The number of local degrees of freedom.

      .. !! processed by numpydoc !!

   .. py:property:: global_size

      The global number of degrees of freedom.

      .. !! processed by numpydoc !!

   .. py:property:: ownership_range

      A tuple `(n0, n1)`, indicating that `slice(n0, n1)` is the range of
      nodes in the global vector owned by this process.

      .. !! processed by numpydoc !!

   .. py:method:: rdtype()

      The real data type associated with the space.

      .. !! processed by numpydoc !!

.. py:class:: Vector(V, *, name=None, space_type='primal', static=False, cache=None, dtype=None, comm=None)

   Vector, with degrees of freedom stored as a JAX array.

   :arg V: A :class:`.VectorSpace`, an :class:`int` defining the number of
       local degrees of freedom, or an ndim 1 array defining the local
       degrees of freedom.
   :arg name: A :class:`str` name for the :class:`.Vector`.
   :arg space_type: The space type for the :class:`.Vector`. `'primal'`,
       `'dual'`, `'conjugate'`, or `'conjugate_dual'`.
   :arg static: Defines whether the :class:`.Vector` is static, meaning that
       it is stored by reference in checkpointing/replay, and an associated
       tangent-linear variable is zero.
   :arg cache: Defines whether results involving the :class:`.Vector` may be
       cached. Default `static`.
   :arg dtype: The data type. Ignored if `V` is a :class:`.VectorSpace`.
   :arg comm: A communicator. Ignored if `V` is a :class:`.VectorSpace`.

   .. !! processed by numpydoc !!

   .. py:property:: name

      The :class:`str` name of the :class:`.Vector`.

      .. !! processed by numpydoc !!

   .. py:property:: value

      For a :class:`.Vector` with one element, the value of the element.

      The value may also be accessed by casting using :class:`float` or
      :class:`complex`.

      :returns: The value.

      .. !! processed by numpydoc !!

   .. py:property:: space

      The :class:`.VectorSpace` for the :class:`.Vector`.

      .. !! processed by numpydoc !!

   .. py:property:: space_type

      The space type for the :class:`.Vector`.

      .. !! processed by numpydoc !!

   .. py:property:: vector

      A JAX array storing the local degrees of freedom.

      .. !! processed by numpydoc !!

   .. py:method:: new(y=None, *, name=None, static=False, cache=None)

      Return a new :class:`.Vector`, with the same :class:`.VectorSpace` and
      space type as this :class:`.Vector`.

      :arg y: Defines a value for the new :class:`.Vector`.
      :returns: The new :class:`.Vector`.

      Remaining arguments are as for the :class:`.Vector` constructor.

      .. !! processed by numpydoc !!

   .. py:method:: assign(y)

      :class:`.Vector` assignment.

      :arg y: A :class:`numbers.Complex`, :class:`.Vector`, or ndim 1 array
          defining the value.
      :returns: The :class:`.Vector`.

      .. !! processed by numpydoc !!

   .. py:method:: addto(y)

      :class:`.Vector` in-place addition.

      :arg y: A :class:`numbers.Complex`, :class:`.Vector`, or ndim 1 array
          defining the value to add.

      .. !! processed by numpydoc !!

.. py:class:: VectorEquation(X, Y, fn, *, with_tlm=True, _forward_eq=None)

   JAX interface.

   `fn` should be a callable

   .. code-block:: python

       def fn(y0, y1, ...):
           ...
           return x0, x1, ...

   where the `y0`, `y1` are ndim 1 JAX arrays, and the `x0`, `x1`, are
   scalars or ndim 1 JAX arrays.

   :arg X: A :class:`.Vector` or a :class:`Sequence` of :class:`.Vector`
       objects defining outputs, whose value is set by the return value from
       `fn`.
   :arg Y: A :class:`.Vector` or a :class:`Sequence` of :class:`.Vector`
       objects defining the inputs, whose values are passed to `fn`.
   :arg fn: A callable.
   :arg with_tlm: Whether to annotate an equation solving for the forward
       and all tangent-linears (`with_tlm=True`), or solving only for the
       forward (`with_tlm=False`).

   .. !! processed by numpydoc !!

   .. py:method:: solve(*, annotate=None, tlm=None)

      Compute the forward solution.

      :arg annotate: Whether the :class:`.EquationManager` should record the
          solution of equations.
      :arg tlm: Whether tangent-linear equations should be solved.

      .. !! processed by numpydoc !!

   .. py:method:: forward_solve(X, deps=None)

      Compute the forward solution.

      Can assume that the currently active :class:`.EquationManager` is
      paused.

      :arg X: A variable or a :class:`Sequence` of variables storing the
          solution. May define an initial guess, and should be set by this
          method.
      :arg deps: A :class:`tuple` of variables, defining values for
          dependencies. Only the elements corresponding to `X` may be
          modified. `self.dependencies()` should be used if not supplied.

      .. !! processed by numpydoc !!

   .. py:method:: adjoint_jacobian_solve(adj_X, nl_deps, B)

      Compute an adjoint solution.

      :arg adj_X: Either `None`, or a variable or :class:`Sequence` of
          variables defining the initial guess for an iterative solve. May
          be modified or returned.
      :arg nl_deps: A :class:`Sequence` of variables defining values for
          non-linear dependencies. Should not be modified.
      :arg B: The right-hand-side. A variable or :class:`Sequence` of
          variables storing the value of the right-hand-side. May be
          modified or returned.
      :returns: A variable or :class:`Sequence` of variables storing the
          value of the adjoint solution. May return `None` to indicate a
          value of zero.

      .. !! processed by numpydoc !!

   .. py:method:: subtract_adjoint_derivative_actions(adj_X, nl_deps, dep_Bs)

      Subtract terms from other adjoint right-hand-sides.

      Can be overridden for an optimized implementation, but otherwise uses
      :meth:`.Equation.adjoint_derivative_action`.

      :arg adj_X: The adjoint solution. A variable or a :class:`Sequence` of
          variables. Should not be modified.
      :arg nl_deps: A :class:`Sequence` of variables defining values for
          non-linear dependencies. Should not be modified.
      :arg dep_Bs: A :class:`Mapping` whose items are `(dep_index, dep_B)`.
          Each `dep_B` is an :class:`.AdjointRHS` which should be updated by
          subtracting adjoint derivative information computed by
          differentiating with respect to `self.dependencies()[dep_index]`.

      .. !! processed by numpydoc !!

   .. py:method:: tangent_linear(tlm_map)

      Derive an :class:`.Equation` corresponding to a tangent-linear
      operation.

      :arg tlm_map: A :class:`.TangentLinearMap` storing values for
          tangent-linear variables.
      :returns: An :class:`.Equation`, corresponding to the tangent-linear
          operation.

      .. !! processed by numpydoc !!

.. py:function:: call_jax(X, Y, fn)

   JAX interface.

   `fn` should be a callable

   .. code-block:: python

       def fn(y0, y1, ...):
           ...
           return x0, x1, ...

   where the `y0`, `y1` are ndim 1 JAX arrays, and the `x0`, `x1`, are
   scalars or ndim 1 JAX arrays.

   :arg X: A :class:`.Vector` or a :class:`Sequence` of :class:`.Vector`
       objects defining outputs, whose value is set by the return value from
       `fn`.
   :arg Y: A :class:`.Vector` or a :class:`Sequence` of :class:`.Vector`
       objects defining the inputs, whose values are passed to `fn`.
   :arg fn: A callable.

   .. !! processed by numpydoc !!

.. py:function:: new_jax(y, space=None, *, name=None)

   Construct a new zero-valued :class:`.Vector`.

   :arg y: A variable.
   :arg space: The :class:`.VectorSpace` for the return value.
   :arg name: A :class:`str` name.
   :returns: The :class:`.Vector`.

   .. !! processed by numpydoc !!

.. py:function:: to_jax(y, space=None, *, name=None)

   Convert a variable to a :class:`.Vector`.

   :arg y: A variable.
   :arg space: The :class:`.VectorSpace` for the return value.
   :arg name: A :class:`str` name.
   :returns: The :class:`.Vector`.

   .. !! processed by numpydoc !!

.. py:function:: new_jax_float(space=None, *, name=None, dtype=None, comm=None)

   Create a new :class:`.Vector` with one element.

   :arg space: The :class:`.VectorSpace`.
   :arg name: A :class:`str` name.
   :arg dtype: The data type. Ignored if `space` is supplied.
   :arg comm: A communicator. Ignored if `space` is supplied.
   :returns: A :class:`.Vector` with one element.

   .. !! processed by numpydoc !!
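
The callable contract described for :class:`.VectorEquation` and
:func:`call_jax` can be illustrated with plain JAX. The following is a minimal
sketch and not the library's implementation: the function `fn`, the input
values, and the chosen directions are hypothetical, and `jax.jvp` and
`jax.vjp` are used only to show what a tangent-linear and an adjoint
computation look like for a callable of this shape.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def fn(y0, y1):
        # Inputs are ndim 1 JAX arrays. Outputs may be scalars or ndim 1
        # JAX arrays, matching the contract documented for call_jax.
        x0 = jnp.dot(y0, y1)   # scalar output
        x1 = y0 * jnp.sin(y1)  # ndim 1 output
        return x0, x1


    # Hypothetical input values, chosen only for illustration.
    y0 = jnp.array([1.0, 2.0, 3.0])
    y1 = jnp.array([0.5, -1.0, 2.0])

    # Forward evaluation.
    x0, x1 = fn(y0, y1)

    # Tangent-linear: directional derivative of the outputs with respect
    # to y0 in the direction dy0, holding y1 fixed.
    dy0 = jnp.ones_like(y0)
    dy1 = jnp.zeros_like(y1)
    _, (tau_x0, tau_x1) = jax.jvp(fn, (y0, y1), (dy0, dy1))

    # Adjoint: pull a cotangent on the outputs back to the inputs. Seeding
    # only x0 yields the gradient of x0 with respect to y0 and y1.
    _, vjp_fn = jax.vjp(fn, y0, y1)
    adj_y0, adj_y1 = vjp_fn((jnp.ones_like(x0), jnp.zeros_like(x1)))

Given such a callable, and assuming `X` and `Y` are suitably sized
:class:`.Vector` objects (or sequences of them), the documented signature
suggests a call of the form `call_jax(X, Y, fn)`.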