"""
Backend magic inherited from tensornetwork: jax backend
"""
# pylint: disable=invalid-name

from functools import partial
import logging
import warnings
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import numpy as np
from scipy.sparse import coo_matrix
import tensornetwork
from tensornetwork.backends.jax import jax_backend
from .abstract_backend import ExtendedBackend

logger = logging.getLogger(__name__)


dtypestr: str
Tensor = Any
PRNGKeyArray = Any  # libjax.random.PRNGKeyArray
pytree = Any

libjax: Any
jnp: Any
jsp: Any
optax: Any
sparse: Any


class optax_optimizer:
    # the behavior of this optimizer abstraction with jit is not guaranteed
    def __init__(self, optimizer: Any) -> None:
        self.optimizer = optimizer
        self.state = None

    def update(self, grads: pytree, params: pytree) -> pytree:
        if self.state is None:
            self.state = self.optimizer.init(params)
        updates, self.state = self.optimizer.update(grads, self.state)
        params = optax.apply_updates(params, updates)
        return params
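
# A hedged usage sketch for the wrapper above (requires optax to be installed;
# `grads` and `params` are any matching pytrees, names are illustrative):
#
#   opt = optax_optimizer(optax.adam(learning_rate=1e-2))
#   params = opt.update(grads, params)  # state is initialized lazily on first call
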
def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
    if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
        raise TypeError(
            ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
        )
    result = jnp.asarray(tensor)
    return result


def _svd_jax(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = -1,
    max_singular_values: Optional[int] = None,
    max_truncation_error: Optional[float] = None,
    relative: Optional[bool] = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    from .jax_ops import adaware_svd_jit as adaware_svd

    left_dims = tensor.shape[:pivot_axis]
    right_dims = tensor.shape[pivot_axis:]

    tensor = jnp.reshape(tensor, [np.prod(left_dims), np.prod(right_dims)])
    u, s, vh = adaware_svd(tensor)

    if max_singular_values is None:
        max_singular_values = jnp.size(s)

    if max_truncation_error is not None:
        # Cumulative norms of singular values in ascending order.
        trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
        # If relative is true, rescale max_truncation_error with the largest
        # singular value to yield the absolute maximal truncation error.
        if relative:
            abs_max_truncation_error = max_truncation_error * s[0]
        else:
            abs_max_truncation_error = max_truncation_error
        # We must keep at least this many singular values to ensure the
        # truncation error is <= abs_max_truncation_error.
        num_sing_vals_err = jnp.count_nonzero(
            (trunc_errs > abs_max_truncation_error).astype(jnp.int32)
        )
    else:
        num_sing_vals_err = max_singular_values

    num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)

    s = s.astype(tensor.dtype)

    s_rest = s[num_sing_vals_keep:]
    s = s[:num_sing_vals_keep]
    u = u[:, :num_sing_vals_keep]
    vh = vh[:num_sing_vals_keep, :]

    dim_s = s.shape[0]
    u = jnp.reshape(u, list(left_dims) + [dim_s])
    vh = jnp.reshape(vh, [dim_s] + list(right_dims))

    return u, s, vh, s_rest


def _qr_jax(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = -1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the QR decomposition of a tensor.
    See tensornetwork.backends.tensorflow.decompositions for details.
    """
    from .jax_ops import adaware_qr_jit as adaware_qr

    left_dims = tensor.shape[:pivot_axis]
    right_dims = tensor.shape[pivot_axis:]
    tensor = jnp.reshape(tensor, [np.prod(left_dims), np.prod(right_dims)])
    q, r = adaware_qr(tensor)
    if non_negative_diagonal:
        phases = jnp.sign(jnp.diagonal(r))
        q = q * phases
        r = phases.conj()[:, None] * r
    center_dim = q.shape[1]
    q = jnp.reshape(q, list(left_dims) + [center_dim])
    r = jnp.reshape(r, [center_dim] + list(right_dims))
    return q, r


def _rq_jax(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = -1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the RQ (reversed QR) decomposition of a tensor.
    See tensornetwork.backends.tensorflow.decompositions for details.
    """
    from .jax_ops import adaware_qr_jit as adaware_qr

    left_dims = tensor.shape[:pivot_axis]
    right_dims = tensor.shape[pivot_axis:]
    tensor = jnp.reshape(tensor, [np.prod(left_dims), np.prod(right_dims)])
    q, r = adaware_qr(jnp.conj(jnp.transpose(tensor)))
    if non_negative_diagonal:
        phases = jnp.sign(jnp.diagonal(r))
        q = q * phases
        r = phases.conj()[:, None] * r
    r, q = jnp.conj(jnp.transpose(r)), jnp.conj(jnp.transpose(q))  # M=r*q at this point
    center_dim = r.shape[1]
    r = jnp.reshape(r, list(left_dims) + [center_dim])
    q = jnp.reshape(q, [center_dim] + list(right_dims))
    return r, q


def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
    from .jax_ops import adaware_eigh_jit as adaware_eigh

    return adaware_eigh(tensor)


# patch tensornetwork's JaxBackend in place so that tensor decompositions
# use the AD-aware implementations defined above
tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
    _convert_to_tensor_jax
)
tensornetwork.backends.jax.jax_backend.JaxBackend.svd = _svd_jax
tensornetwork.backends.jax.jax_backend.JaxBackend.qr = _qr_jax
tensornetwork.backends.jax.jax_backend.JaxBackend.rq = _rq_jax
tensornetwork.backends.jax.jax_backend.JaxBackend.eigh = _eigh_jax

class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
    """
    See the original backend API at `jax backend
    <https://github.com/google/TensorNetwork/blob/master/tensornetwork/backends/jax/jax_backend.py>`_
    """

    # Jax doesn't support 64-bit dtypes unless one runs
    # ``from jax.config import config``
    # ``config.update("jax_enable_x64", True)``
    # at the very beginning, i.e. before importing tensorcircuit
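    #
    # A hedged sketch of that ordering (``tc.set_backend`` assumed from
    # tensorcircuit's public API):
    #
    #   from jax.config import config
    #   config.update("jax_enable_x64", True)
    #   import tensorcircuit as tc
    #   tc.set_backend("jax")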
    def __init__(self) -> None:
        global libjax  # Jax module
        global jnp  # jax.numpy module
        global jsp  # jax.scipy module
        global sparse  # jax.experimental.sparse module
        global optax  # optax
        super(JaxBackend, self).__init__()
        try:
            import jax
        except ImportError:
            raise ImportError(
                "Jax not installed, please switch to a different "
                "backend or install Jax."
            )
        from jax.experimental import sparse
        import jax.scipy

        try:
            import optax
        except ImportError:
            logger.warning(
                "optax not installed, `optimizer` from jax backend cannot work"
            )
        libjax = jax
        jnp = libjax.numpy
        jsp = libjax.scipy
        self.name = "jax"
    # it is already a child of the numpy backend, and self.np = self.jax.np
    def eye(
        self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
    ) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        r = jnp.eye(N, M=M)
        return self.cast(r, dtype)

    def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        r = jnp.ones(shape)
        return self.cast(r, dtype)

    def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        r = jnp.zeros(shape)
        return self.cast(r, dtype)

    def copy(self, tensor: Tensor) -> Tensor:
        return jnp.array(tensor, copy=True)

    def convert_to_tensor(self, tensor: Tensor) -> Tensor:
        result = jnp.asarray(tensor)
        return result
    def abs(self, a: Tensor) -> Tensor:
        return jnp.abs(a)

    def sin(self, a: Tensor) -> Tensor:
        return jnp.sin(a)

    def cos(self, a: Tensor) -> Tensor:
        return jnp.cos(a)

    def acos(self, a: Tensor) -> Tensor:
        return jnp.arccos(a)

    def acosh(self, a: Tensor) -> Tensor:
        return jnp.arccosh(a)

    def asin(self, a: Tensor) -> Tensor:
        return jnp.arcsin(a)

    def asinh(self, a: Tensor) -> Tensor:
        return jnp.arcsinh(a)

    def atan(self, a: Tensor) -> Tensor:
        return jnp.arctan(a)

    def atan2(self, y: Tensor, x: Tensor) -> Tensor:
        return jnp.arctan2(y, x)

    def atanh(self, a: Tensor) -> Tensor:
        return jnp.arctanh(a)

    def cosh(self, a: Tensor) -> Tensor:
        return jnp.cosh(a)

    def tan(self, a: Tensor) -> Tensor:
        return jnp.tan(a)

    def tanh(self, a: Tensor) -> Tensor:
        return jnp.tanh(a)

    def sinh(self, a: Tensor) -> Tensor:
        return jnp.sinh(a)

    def size(self, a: Tensor) -> Tensor:
        return jnp.size(a)

    def eigvalsh(self, a: Tensor) -> Tensor:
        return jnp.linalg.eigvalsh(a)

    def kron(self, a: Tensor, b: Tensor) -> Tensor:
        return jnp.kron(a, b)
    def numpy(self, a: Tensor) -> Tensor:
        if self.is_sparse(a):
            return coo_matrix(
                (a.data, (a.indices[:, 0], a.indices[:, 1])), shape=a.shape
            )
        return np.array(a)
    def i(self, dtype: Any = None) -> Tensor:
        if not dtype:
            dtype = npdtype  # type: ignore
        if isinstance(dtype, str):
            dtype = getattr(jnp, dtype)
        return jnp.array(1j, dtype=dtype)

    def det(self, a: Tensor) -> Tensor:
        return jnp.linalg.det(a)

    def schur(self, a: Tensor, output: str = "real") -> Tuple[Tensor, Tensor]:
        return jsp.linalg.schur(a, output=output)  # type: ignore

    def real(self, a: Tensor) -> Tensor:
        return jnp.real(a)

    def imag(self, a: Tensor) -> Tensor:
        return jnp.imag(a)

    def dtype(self, a: Tensor) -> str:
        return a.dtype.__str__()  # type: ignore

    def cast(self, a: Tensor, dtype: str) -> Tensor:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", np.ComplexWarning)
            if isinstance(dtype, str):
                return a.astype(getattr(jnp, dtype))
            return a.astype(dtype)

    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
        if stop is None:
            return jnp.arange(start=0, stop=start, step=step)
        return jnp.arange(start=start, stop=stop, step=step)

    def mod(self, x: Tensor, y: Tensor) -> Tensor:
        return jnp.mod(x, y)

    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
        return jnp.right_shift(x, y)

    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
        return jnp.left_shift(x, y)

    def expm(self, a: Tensor) -> Tensor:
        return jsp.linalg.expm(a)
    # currently expm in jax doesn't support AD; it will raise an AssertionError,
    # see https://github.com/google/jax/issues/2645
    def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
        return jnp.stack(a, axis=axis)

    def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
        return jnp.concatenate(a, axis=axis)

    def tile(self, a: Tensor, rep: Tensor) -> Tensor:
        return jnp.tile(a, rep)

    def mean(
        self,
        a: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False,
    ) -> Tensor:
        return jnp.mean(a, axis=axis, keepdims=keepdims)

    def std(
        self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False
    ) -> Tensor:
        return jnp.std(a, axis=axis, keepdims=keepdims)

    def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        return jnp.min(a, axis=axis)

    def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        return jnp.max(a, axis=axis)

    def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
        return jnp.argmax(a, axis=axis)

    def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
        return jnp.argmin(a, axis=axis)

    def unique_with_counts(  # type: ignore
        self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
    ) -> Tuple[Tensor, Tensor]:
        return jnp.unique(a, return_counts=True, size=size, fill_value=fill_value)  # type: ignore
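    # A hedged usage sketch (``backend`` is an illustrative instance name):
    # under `jit` output shapes must be static, so `size` and `fill_value`
    # let `jnp.unique` return fixed-shape, padded results, e.g.
    #
    #   a = jnp.array([0, 1, 1, 3])
    #   vals, counts = backend.unique_with_counts(a, size=4, fill_value=-1)
    #   # vals -> [0, 1, 3, -1], counts -> [1, 2, 1, 0]  (padded to length 4)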
    def sigmoid(self, a: Tensor) -> Tensor:
        return libjax.nn.sigmoid(a)

    def relu(self, a: Tensor) -> Tensor:
        return libjax.nn.relu(a)

    def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor:
        return libjax.nn.softmax(a, axis=axis)

    def onehot(self, a: Tensor, num: int) -> Tensor:
        return libjax.nn.one_hot(a, num)

    def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        return jnp.cumsum(a, axis)

    def is_tensor(self, a: Any) -> bool:
        if not isinstance(a, jnp.ndarray):
            return False
        # isinstance(np.eye(1), jax.numpy.ndarray) = True!
        if getattr(a, "_value", None) is not None:
            return True
        return False
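    # A hedged illustration of the ambiguity noted above: on some jax versions
    # numpy arrays also pass the isinstance check, hence the extra `_value`
    # probe to pick out genuine jax arrays (expected results):
    #
    #   backend.is_tensor(jnp.ones([2]))  # -> True
    #   backend.is_tensor(np.ones([2]))   # -> False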
    def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor:  # type: ignore
        return jsp.linalg.solve(A, b, assume_a=assume_a)

    def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
        if not self.is_tensor(a):
            a = self.convert_to_tensor(a)
        if not self.is_tensor(v):
            v = self.convert_to_tensor(v)
        return jnp.searchsorted(a, v, side)

    def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
        return libjax.tree_map(f, *pytrees)

    def tree_flatten(self, pytree: Any) -> Tuple[Any, Any]:
        return libjax.tree_util.tree_flatten(pytree)  # type: ignore

    def tree_unflatten(self, treedef: Any, leaves: Any) -> Any:
        return libjax.tree_util.tree_unflatten(treedef, leaves)

    def from_dlpack(self, a: Any) -> Tensor:
        import jax.dlpack

        return jax.dlpack.from_dlpack(a)

    def to_dlpack(self, a: Tensor) -> Any:
        import jax.dlpack

        return jax.dlpack.to_dlpack(a)
    def set_random_state(
        self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
    ) -> Any:
        if seed is None:
            seed = np.random.randint(42)
        if isinstance(seed, int):
            g = libjax.random.PRNGKey(seed)
        else:
            g = seed
        if get_only is False:
            self.g = g
        return g
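    # A hedged usage sketch (``backend`` is an illustrative instance name):
    #
    #   backend = JaxBackend()
    #   backend.set_random_state(42)  # seeds and stores a PRNGKey on backend.g
    #   key = backend.set_random_state(42, get_only=True)  # returns the key only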
    def random_split(self, key: Any) -> Tuple[Any, Any]:
        return libjax.random.split(key)  # type: ignore

    def implicit_randn(
        self,
        shape: Union[int, Sequence[int]] = 1,
        mean: float = 0,
        stddev: float = 1,
        dtype: str = "32",
    ) -> Tensor:
        g = getattr(self, "g", None)
        if g is None:
            # or getattr(g, "_trace", None) is not None:
            # avoid the case where the random state was set inside a jitted function;
            # calling it outside the jitted regime then leads to UnexpectedTracerError.
            # checking _trace is bad, since the function can itself run in a jit env
            self.set_random_state()
            g = getattr(self, "g", None)
        try:
            key, subkey = libjax.random.split(g)
        except libjax.errors.UnexpectedTracerError:
            self.set_random_state()
            g = getattr(self, "g", None)
            key, subkey = libjax.random.split(g)
        r = self.stateful_randn(subkey, shape, mean, stddev, dtype)
        self.g = key
        return r

    def implicit_randu(
        self,
        shape: Union[int, Sequence[int]] = 1,
        low: float = 0,
        high: float = 1,
        dtype: str = "32",
    ) -> Tensor:
        g = getattr(self, "g", None)
        if g is None:
            # checking _trace is bad, since the function can itself run in a jit env
            self.set_random_state()
            g = getattr(self, "g", None)
        try:
            key, subkey = libjax.random.split(g)
        except libjax.errors.UnexpectedTracerError:
            self.set_random_state()
            g = getattr(self, "g", None)
            key, subkey = libjax.random.split(g)
        r = self.stateful_randu(subkey, shape, low, high, dtype)
        self.g = key
        return r

    def implicit_randc(
        self,
        a: Union[int, Sequence[int], Tensor],
        shape: Union[int, Sequence[int]],
        p: Optional[Union[Sequence[float], Tensor]] = None,
    ) -> Tensor:
        g = getattr(self, "g", None)
        if g is None:
            self.set_random_state()
            g = getattr(self, "g", None)
        try:
            key, subkey = libjax.random.split(g)
        except libjax.errors.UnexpectedTracerError:
            self.set_random_state()
            g = getattr(self, "g", None)
            key, subkey = libjax.random.split(g)
        r = self.stateful_randc(subkey, a, shape, p)
        self.g = key
        return r

    def stateful_randn(
        self,
        g: PRNGKeyArray,
        shape: Union[int, Sequence[int]] = 1,
        mean: float = 0,
        stddev: float = 1,
        dtype: str = "32",
    ) -> Tensor:
        if isinstance(dtype, str):
            dtype = dtype[-2:]
        if isinstance(shape, int):
            shape = (shape,)
        if dtype == "32":
            dtyper = jnp.float32
        elif dtype == "64":
            dtyper = jnp.float64
        elif not isinstance(dtype, str):
            dtyper = dtype
        r = libjax.random.normal(g, shape=shape, dtype=dtyper) * stddev + mean
        return r

    def stateful_randu(
        self,
        g: PRNGKeyArray,
        shape: Union[int, Sequence[int]] = 1,
        low: float = 0,
        high: float = 1,
        dtype: str = "32",
    ) -> Tensor:
        if isinstance(dtype, str):
            dtype = dtype[-2:]
        if isinstance(shape, int):
            shape = (shape,)
        if dtype == "32":
            dtyper = jnp.float32
        elif dtype == "64":
            dtyper = jnp.float64
        elif not isinstance(dtype, str):
            dtyper = dtype
        r = libjax.random.uniform(g, shape=shape, dtype=dtyper, minval=low, maxval=high)
        return r

    def stateful_randc(
        self,
        g: PRNGKeyArray,
        a: Union[int, Sequence[int], Tensor],
        shape: Union[int, Sequence[int]],
        p: Optional[Union[Sequence[float], Tensor]] = None,
    ) -> Tensor:
        if isinstance(shape, int):
            shape = (shape,)
        if not self.is_tensor(a):
            a = jnp.array(a)
        if p is not None:
            if not self.is_tensor(p):
                p = jnp.array(p)
        return libjax.random.choice(g, a, shape=shape, replace=True, p=p)
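    # A hedged contrast of the two RNG styles (``backend`` as above): the
    # stateful_* methods are pure, consuming an explicit PRNG key, while the
    # implicit_* methods split and advance the key stored on ``self.g``:
    #
    #   key = libjax.random.PRNGKey(0)
    #   r1 = backend.stateful_randn(key, shape=(2, 3), stddev=0.1)  # pure
    #   r2 = backend.implicit_randn(shape=(2, 3))  # advances backend.g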
    def cond(
        self,
        pred: bool,
        true_fun: Callable[[], Tensor],
        false_fun: Callable[[], Tensor],
    ) -> Tensor:
        return libjax.lax.cond(pred, lambda _: true_fun(), lambda _: false_fun(), None)

    def switch(
        self, index: Tensor, branches: Sequence[Callable[[], Tensor]]
    ) -> Tensor:
        # branches_null = [lambda _: b() for b in branches]
        # see https://stackoverflow.com/a/34021333 for weird behavior of lambda
        # with list comprehension
        branches_null = [lambda _, f=b: f() for b in branches]
        return libjax.lax.switch(index, branches_null, None)
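    # A minimal sketch of the late-binding pitfall referenced above (plain
    # Python, independent of jax):
    #
    #   fs = [lambda: i for i in range(3)]
    #   [f() for f in fs]  # -> [2, 2, 2]; every closure sees the final i
    #   fs = [lambda i=i: i for i in range(3)]
    #   [f() for f in fs]  # -> [0, 1, 2]; default args bind at definition time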
    def scan(
        self, f: Callable[[Tensor, Tensor], Tensor], xs: Tensor, init: Tensor
    ) -> Tensor:
        def f_jax(*args: Any, **kws: Any) -> Any:
            r = f(*args, **kws)
            return r, None

        carry, _ = libjax.lax.scan(f_jax, init, xs)
        return carry

    def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
        rank = len(operand.shape)
        dnums = libjax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=tuple([i for i in range(rank)]),
            scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
        )
        r = libjax.lax.scatter(operand, indices, updates, dnums)
        return r
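    # A hedged illustration of the element-wise scatter above, shown for a
    # rank-1 operand (expected values, ``backend`` is illustrative):
    #
    #   operand = jnp.zeros([3])
    #   indices = jnp.array([[0], [2]])  # one row of coordinates per update
    #   updates = jnp.array([1.0, 2.0])
    #   backend.scatter(operand, indices, updates)  # -> [1.0, 0.0, 2.0]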
    def coo_sparse_matrix(
        self, indices: Tensor, values: Tensor, shape: Tensor
    ) -> Tensor:
        return sparse.BCOO((values, indices), shape=shape)  # type: ignore

    def sparse_dense_matmul(
        self,
        sp_a: Tensor,
        b: Tensor,
    ) -> Tensor:
        return sp_a @ b

    def to_dense(self, sp_a: Tensor) -> Tensor:
        return sp_a.todense()

    def is_sparse(self, a: Tensor) -> bool:
        return isinstance(a, sparse.BCOO)  # type: ignore

    def device(self, a: Tensor) -> str:
        dev = a.device()
        return self._dev2str(dev)

    def device_move(self, a: Tensor, dev: Any) -> Tensor:
        if isinstance(dev, str):
            dev = self._str2dev(dev)
        return libjax.device_put(a, dev)
    def _dev2str(self, dev: Any) -> str:
        if dev.platform == "cpu":
            return "cpu"
        if dev.platform == "gpu":
            return "gpu:" + str(dev.id)
        raise ValueError("JaxBackend doesn't support non-GPU/CPU devices")

    def _str2dev(self, str_: str) -> Any:
        if str_ == "cpu":
            return libjax.devices("cpu")[0]
        if str_.startswith("gpu"):
            _id = int(str_.split(":")[-1])
            return libjax.devices("gpu")[_id]
        raise ValueError("JaxBackend doesn't support non-GPU/CPU devices")
    def stop_gradient(self, a: Tensor) -> Tensor:
        return libjax.lax.stop_gradient(a)

    def grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Any:
        return libjax.grad(f, argnums=argnums, has_aux=has_aux)

    def value_and_grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Callable[..., Tuple[Any, Any]]:
        return libjax.value_and_grad(f, argnums=argnums, has_aux=has_aux)  # type: ignore

    def jvp(
        self,
        f: Callable[..., Any],
        inputs: Union[Tensor, Sequence[Tensor]],
        v: Union[Tensor, Sequence[Tensor]],
    ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs,)
        if not isinstance(v, (tuple, list)):
            v = (v,)
        value, jvpv = libjax.jvp(f, inputs, v)
        return value, jvpv

    def vjp(
        self,
        f: Callable[..., Any],
        inputs: Union[Tensor, Sequence[Tensor]],
        v: Union[Tensor, Sequence[Tensor]],
    ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
        if not (isinstance(inputs, list) or isinstance(inputs, tuple)):
            # one input tensor
            inputs = [inputs]
            one_input = True
        else:
            one_input = False
        value, vjpf = libjax.vjp(f, *inputs)
        if isinstance(v, list):
            v = tuple(v)
        vjpv = vjpf(v)
        if one_input:
            vjpv = vjpv[0]
        return value, vjpv

    def jit(
        self,
        f: Callable[..., Any],
        static_argnums: Optional[Union[int, Sequence[int]]] = None,
        jit_compile: Optional[bool] = None,
        **kws: Any,
    ) -> Any:
        return libjax.jit(f, static_argnums=static_argnums)

    def vmap(
        self, f: Callable[..., Any], vectorized_argnums: Union[int, Sequence[int]] = 0
    ) -> Any:
        if isinstance(vectorized_argnums, int):
            vectorized_argnums = (vectorized_argnums,)
        # if vectorized_argnums == (0,):  # fast shortcuts
        #     return libjax.vmap(f)

        def wrapper(*args: Any, **kws: Any) -> Tensor:
            in_axes = [0 if i in vectorized_argnums else None for i in range(len(args))]  # type: ignore
            return libjax.vmap(f, in_axes, 0)(*args, **kws)

        return wrapper
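    # A hedged usage sketch: only arguments listed in ``vectorized_argnums``
    # receive a leading batch axis; the others are passed through unbatched:
    #
    #   def f(x, w):
    #       return jnp.dot(x, w)
    #
    #   vf = backend.vmap(f, vectorized_argnums=0)
    #   vf(jnp.ones([5, 3]), jnp.ones([3]))  # shape [5]: 5 dot products, shared w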
    # since tf doesn't support general in&out axes options,
    # we don't support them in the universal backend
    def vectorized_value_and_grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        vectorized_argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Callable[..., Tuple[Any, Any]]:
        if isinstance(vectorized_argnums, int):
            vectorized_argnums = (vectorized_argnums,)

        def wrapper(
            *args: Any, **kws: Any
        ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
            jf = self.value_and_grad(f, argnums=argnums, has_aux=has_aux)
            in_axes = [0 if i in vectorized_argnums else None for i in range(len(args))]  # type: ignore
            jf = libjax.vmap(jf, in_axes, 0)
            # jf = self.jit(jf)
            vs, gs = jf(*args, **kws)

            if isinstance(argnums, int):
                argnums_list = [argnums]
                gs = [gs]
            else:
                argnums_list = argnums  # type: ignore
                gs = list(gs)
            for i, (j, g) in enumerate(zip(argnums_list, gs)):
                if j not in vectorized_argnums:  # type: ignore
                    gs[i] = libjax.tree_map(partial(jnp.sum, axis=0), g)
            if isinstance(argnums, int):
                gs = gs[0]
            else:
                gs = tuple(gs)

            return vs, gs

        return wrapper
        # f = self.value_and_grad(f, argnums=argnums)
        # f = libjax.vmap(f, (0, None), 0)
        # f = self.jit(f)
        # return f

    vvag = vectorized_value_and_grad

    optimizer = optax_optimizer
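
# A hedged end-to-end sketch of the batched value-and-grad above (public
# tensorcircuit entry points assumed): vectorize over a batch of inputs
# (argnum 0) while differentiating w.r.t. shared parameters (argnum 1); the
# per-sample gradients of the shared parameters are summed over the batch:
#
#   import tensorcircuit as tc
#   K = tc.set_backend("jax")  # an instance of JaxBackend
#
#   def loss(x, w):
#       return jnp.sum(jnp.dot(x, w) ** 2)
#
#   vvag_loss = K.vvag(loss, argnums=1, vectorized_argnums=0)
#   values, grad_w = vvag_loss(jnp.ones([5, 3]), jnp.ones([3]))
#   # values: shape [5]; grad_w: shape [3], summed over the batch dimension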