# Copyright 2019 The TensorNetwork Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Callable, List, Text, Type, Sequence
from typing import Union
from tensornetwork.backends import abstract_backend
from tensornetwork.backends.numpy import decompositions
import numpy as np
from tensornetwork.backends.jax import jitted_functions
from functools import partial
import warnings
# Type alias used throughout this module; concrete values are jax/numpy arrays.
Tensor = Any
# pylint: disable=abstract-method
# Module-level caches shared by all JaxBackend instances.
# _CACHED_MATVECS maps a user-supplied matvec callable (keyed by the callable
# object itself, i.e. its identity) to a jit-compiled wrapper; _CACHED_FUNCTIONS
# maps helper names ("imp_arnoldi", "imp_lanczos", "eigsh_lanczos", "gmres") to
# the jitted implementations, so repeated calls do not re-trigger compilation.
_CACHED_MATVECS = {}
_CACHED_FUNCTIONS = {}
class JaxBackend(abstract_backend.AbstractBackend):
  """See abstract_backend.AbstractBackend for documentation."""

  def __init__(self, dtype: Optional[np.dtype] = None,
               precision: Optional[Text] = None) -> None:
    """Initialize the jax backend.

    Args:
      dtype: Optional default dtype for tensors created by this backend.
      precision: Optional jax matmul precision (e.g.
        `jax.lax.Precision.HIGHEST`). Defaults to
        `jax.lax.Precision.DEFAULT`.

    Raises:
      ImportError: If jax is not installed.
    """
    # jax is imported lazily (on first backend construction) so that merely
    # importing tensornetwork does not require jax to be installed. The
    # modules are published as module-level globals because the methods below
    # reference `jnp`/`jsp`/`libjax` directly.
    # pylint: disable=global-variable-undefined
    global libjax  # Jax module
    global jnp  # jax.numpy module
    global jsp  # jax.scipy module
    super().__init__()
    try:
      #pylint: disable=import-outside-toplevel
      import jax
    except ImportError as err:
      raise ImportError("Jax not installed, please switch to a different "
                        "backend or install Jax.") from err
    libjax = jax
    jnp = libjax.numpy
    jsp = libjax.scipy
    self.name = "jax"
    self._dtype = np.dtype(dtype) if dtype is not None else None
    self.jax_precision = precision if precision is not None else libjax.lax.Precision.DEFAULT  #pylint: disable=line-too-long
[docs] def tensordot(self, a: Tensor, b: Tensor,
axes: Union[int, Sequence[Sequence[int]]]) -> Tensor:
return jnp.tensordot(a, b, axes, precision=self.jax_precision)
[docs] def reshape(self, tensor: Tensor, shape: Tensor) -> Tensor:
return jnp.reshape(tensor, np.asarray(shape).astype(np.int32))
[docs] def transpose(self, tensor, perm=None) -> Tensor:
return jnp.transpose(tensor, perm)
[docs] def shape_concat(self, values: Tensor, axis: int) -> Tensor:
return np.concatenate(values, axis)
[docs] def slice(self, tensor: Tensor, start_indices: Tuple[int, ...],
slice_sizes: Tuple[int, ...]) -> Tensor:
if len(start_indices) != len(slice_sizes):
raise ValueError("Lengths of start_indices and slice_sizes must be"
"identical.")
return libjax.lax.dynamic_slice(tensor, start_indices, slice_sizes)
def svd(
self,
tensor: Tensor,
pivot_axis: int = -1,
max_singular_values: Optional[int] = None,
max_truncation_error: Optional[float] = None,
relative: Optional[bool] = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return decompositions.svd(
jnp,
tensor,
pivot_axis,
max_singular_values,
max_truncation_error,
relative=relative)
def qr(
self,
tensor: Tensor,
pivot_axis: int = -1,
non_negative_diagonal: bool = False
) -> Tuple[Tensor, Tensor]:
return decompositions.qr(jnp, tensor, pivot_axis, non_negative_diagonal)
def rq(
self,
tensor: Tensor,
pivot_axis: int = -1,
non_negative_diagonal: bool = False
) -> Tuple[Tensor, Tensor]:
return decompositions.rq(jnp, tensor, pivot_axis, non_negative_diagonal)
[docs] def shape_tensor(self, tensor: Tensor) -> Tensor:
return tensor.shape
[docs] def shape_tuple(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return tensor.shape
[docs] def sparse_shape(self, tensor: Tensor) -> Tuple[Optional[int], ...]:
return self.shape_tuple(tensor)
[docs] def shape_prod(self, values: Tensor) -> Tensor:
return np.prod(values)
[docs] def sqrt(self, tensor: Tensor) -> Tensor:
return jnp.sqrt(tensor)
def convert_to_tensor(self, tensor: Tensor) -> Tensor:
if (not isinstance(tensor, (np.ndarray, jnp.ndarray))
and not jnp.isscalar(tensor)):
raise TypeError(("Expected a `jnp.array`, `np.array` or scalar. "
f"Got {type(tensor)}"))
result = jnp.asarray(tensor)
return result
[docs] def outer_product(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return jnp.tensordot(tensor1, tensor2, 0,
precision=self.jax_precision)
[docs] def einsum(self,
expression: str,
*tensors: Tensor,
optimize: bool = True) -> Tensor:
return jnp.einsum(expression, *tensors, optimize=optimize,
precision=self.jax_precision)
[docs] def norm(self, tensor: Tensor) -> Tensor:
return jnp.linalg.norm(tensor)
def eye(self,
N,
dtype: Optional[np.dtype] = None,
M: Optional[int] = None) -> Tensor:
dtype = dtype if dtype is not None else jnp.float64
return jnp.eye(N, M=M, dtype=dtype)
def ones(self,
shape: Tuple[int, ...],
dtype: Optional[np.dtype] = None) -> Tensor:
dtype = dtype if dtype is not None else jnp.float64
return jnp.ones(shape, dtype=dtype)
def zeros(self,
shape: Tuple[int, ...],
dtype: Optional[np.dtype] = None) -> Tensor:
dtype = dtype if dtype is not None else jnp.float64
return jnp.zeros(shape, dtype=dtype)
[docs] def randn(self,
shape: Tuple[int, ...],
dtype: Optional[np.dtype] = None,
seed: Optional[int] = None) -> Tensor:
if not seed:
seed = np.random.randint(0, 2**63)
key = libjax.random.PRNGKey(seed)
dtype = dtype if dtype is not None else np.dtype(np.float64)
def cmplx_randn(complex_dtype, real_dtype):
real_dtype = np.dtype(real_dtype)
complex_dtype = np.dtype(complex_dtype)
key_2 = libjax.random.PRNGKey(seed + 1)
real_part = libjax.random.normal(key, shape, dtype=real_dtype)
complex_part = libjax.random.normal(key_2, shape, dtype=real_dtype)
unit = (
np.complex64(1j)
if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
return real_part + unit * complex_part
if np.dtype(dtype) is np.dtype(jnp.complex128):
return cmplx_randn(dtype, jnp.float64)
if np.dtype(dtype) is np.dtype(jnp.complex64):
return cmplx_randn(dtype, jnp.float32)
return libjax.random.normal(key, shape).astype(dtype)
[docs] def eigs(self, #pylint: disable=arguments-differ
A: Callable,
args: Optional[List] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 50,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'LR',
maxiter: int = 20) -> Tuple[Tensor, List]:
"""
Implicitly restarted Arnoldi method for finding the lowest
eigenvector-eigenvalue pairs of a linear operator `A`.
`A` is a function implementing the matrix-vector
product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigs`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
import jax
import numpy as np
def A(H,x):
return jax.np.dot(H,x)
for n in range(100):
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigs(A, [H],x) #jitting is triggerd only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
import jax
import numpy as np
for n in range(100):
def A(H,x):
return jax.np.dot(H,x)
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigs(A, [H],x) #jitting is triggerd at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
args: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
tol: The desired precision of the eigenvalues. For the jax backend
this has currently no effect, and precision of eigenvalues is not
guaranteed. This feature may be added at a later point. To increase
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
which: Flag for targetting different types of eigenvalues. Currently
supported are `which = 'LR'` (larges real part) and `which = 'LM'`
(larges magnitude).
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
equivalent to a simple Arnoldi method.
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` eigenvalues
eigvecs: A list of `numeig` eigenvectors
"""
if args is None:
args = []
if which not in ('LR', 'LM'):
raise ValueError(f'which = {which} is currently not supported.')
if numeig > num_krylov_vecs:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = self.randn(shape, dtype)
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
raise TypeError("Expected a `jax.array`. Got {}".format(
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A))
if "imp_arnoldi" not in _CACHED_FUNCTIONS:
imp_arnoldi = jitted_functions._implicitly_restarted_arnoldi(libjax)
_CACHED_FUNCTIONS["imp_arnoldi"] = imp_arnoldi
eta, U, numits = _CACHED_FUNCTIONS["imp_arnoldi"](_CACHED_MATVECS[A], args,
initial_state,
num_krylov_vecs, numeig,
which, tol, maxiter,
self.jax_precision)
if numeig > numits:
warnings.warn(
f"Arnoldi terminated early after numits = {numits}"
f" < numeig = {numeig} steps. For this value of `numeig `"
f"the routine will return spurious eigenvalues of value 0.0."
f"Use a smaller value of numeig, or a smaller value for `tol`")
return eta, U
[docs] def eigsh(
self, #pylint: disable=arguments-differ
A: Callable,
args: Optional[List] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 50,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'SA',
maxiter: int = 20) -> Tuple[Tensor, List]:
"""
Implicitly restarted Lanczos method for finding the lowest
eigenvector-eigenvalue pairs of a symmetric (hermitian) linear operator `A`.
`A` is a function implementing the matrix-vector
product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigsh`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
import jax
import numpy as np
def A(H,x):
return jax.np.dot(H,x)
for n in range(100):
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh(A, [H],x) #jitting is triggerd only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
import jax
import numpy as np
for n in range(100):
def A(H,x):
return jax.np.dot(H,x)
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh(A, [H],x) #jitting is triggerd at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
tol: The desired precision of the eigenvalues. For the jax backend
this has currently no effect, and precision of eigenvalues is not
guaranteed. This feature may be added at a later point. To increase
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
which: Flag for targetting different types of eigenvalues. Currently
supported are `which = 'LR'` (larges real part) and `which = 'LM'`
(larges magnitude).
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
equivalent to a simple Arnoldi method.
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` eigenvalues
eigvecs: A list of `numeig` eigenvectors
"""
if args is None:
args = []
if which not in ('SA', 'LA', 'LM'):
raise ValueError(f'which = {which} is currently not supported.')
if numeig > num_krylov_vecs:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = self.randn(shape, dtype)
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
raise TypeError("Expected a `jax.array`. Got {}".format(
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A))
if "imp_lanczos" not in _CACHED_FUNCTIONS:
imp_lanczos = jitted_functions._implicitly_restarted_lanczos(libjax)
_CACHED_FUNCTIONS["imp_lanczos"] = imp_lanczos
eta, U, numits = _CACHED_FUNCTIONS["imp_lanczos"](_CACHED_MATVECS[A], args,
initial_state,
num_krylov_vecs, numeig,
which, tol, maxiter,
self.jax_precision)
if numeig > numits:
warnings.warn(
f"Arnoldi terminated early after numits = {numits}"
f" < numeig = {numeig} steps. For this value of `numeig `"
f"the routine will return spurious eigenvalues of value 0.0."
f"Use a smaller value of numeig, or a smaller value for `tol`")
return eta, U
[docs] def eigsh_lanczos(
self,
A: Callable,
args: Optional[List[Tensor]] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 20,
numeig: int = 1,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 10,
reorthogonalize: Optional[bool] = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a hermitian linear operator `A`. `A` is a function implementing
the matrix-vector product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigsh_lanczos`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
import jax
import numpy as np
def A(H,x):
return jax.np.dot(H,x)
for n in range(100):
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh_lanczos(A, [H],x) #jitting is triggerd only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
import jax
import numpy as np
for n in range(100):
def A(H,x):
return jax.np.dot(H,x)
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh_lanczos(A, [H],x) #jitting is triggerd at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the Lanczos algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
If `numeig > 1`, `reorthogonalize` has to be `True`.
tol: The desired precision of the eigenvalues. For the jax backend
this has currently no effect, and precision of eigenvalues is not
guaranteed. This feature may be added at a later point.
To increase precision the caller can increase `num_krylov_vecs`.
delta: Stopping criterion for Lanczos iteration.
If a Krylov vector :math: `x_n` has an L2 norm
:math:`\\lVert x_n\\rVert < delta`, the iteration
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations
to check convergence. This has currently no effect for the jax backend,
but may be added at a later point.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
Returns:
(eigvals, eigvecs)
eigvals: A jax-array containing `numeig` lowest eigenvalues
eigvecs: A list of `numeig` lowest eigenvectors
"""
if args is None:
args = []
if num_krylov_vecs < numeig:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')
if numeig > 1 and not reorthogonalize:
raise ValueError(
"Got numeig = {} > 1 and `reorthogonalize = False`. "
"Use `reorthogonalize=True` for `numeig > 1`".format(numeig))
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = self.randn(shape, dtype)
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
raise TypeError("Expected a `jax.array`. Got {}".format(
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = libjax.tree_util.Partial(A)
if "eigsh_lanczos" not in _CACHED_FUNCTIONS:
eigsh_lanczos = jitted_functions._generate_jitted_eigsh_lanczos(libjax)
_CACHED_FUNCTIONS["eigsh_lanczos"] = eigsh_lanczos
eigsh_lanczos = _CACHED_FUNCTIONS["eigsh_lanczos"]
eta, U, numits = eigsh_lanczos(_CACHED_MATVECS[A], args, initial_state,
num_krylov_vecs, numeig, delta,
reorthogonalize, self.jax_precision)
if numeig > numits:
warnings.warn(
f"Lanczos terminated early after numits = {numits}"
f" < numeig = {numeig} steps. For this value of `numeig `"
f"the routine will return spurious eigenvalues of value 0.0."
f"Use a smaller value of numeig, or a smaller value for `tol`")
return eta, U
  def _gmres(self,
             A_mv: Callable,
             b: Tensor,
             A_args: List,
             A_kwargs: dict,
             x0: Tensor,
             tol: float,
             atol: float,
             num_krylov_vectors: int,
             maxiter: int,
             M: Optional[Callable] = None) -> Tuple[Tensor, int]:
    """GMRES solves the linear system A @ x = b for x given a vector `b` and
    a general (not necessarily symmetric/Hermitian) linear operator `A`.

    As a Krylov method, GMRES does not require a concrete matrix representation
    of the n by n `A`, but only a function
    `vector1 = A_mv(vector0, *A_args, **A_kwargs)`
    prescribing a one-to-one linear map from vector0 to vector1 (that is,
    A must be square, and thus vector0 and vector1 the same size). If `A` is a
    dense matrix, or if it is a symmetric/Hermitian operator, a different
    linear solver will usually be preferable.

    GMRES works by first constructing the Krylov basis
    K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and then
    solving a certain dense linear system K @ q0 = q1 from whose solution x can
    be approximated. For `num_krylov_vectors = n` the solution is provably exact
    in infinite precision, but the expense is cubic in `num_krylov_vectors` so
    one is typically interested in the `num_krylov_vectors << n` case.
    The solution can in this case be repeatedly
    improved, to a point, by restarting the Arnoldi iterations each time
    `num_krylov_vectors` is reached. Unfortunately the optimal parameter choices
    balancing expense and accuracy are difficult to predict in advance, so
    applying this function requires a degree of experimentation.

    In a tensor network code one is typically interested in A_mv implementing
    some tensor contraction. This implementation thus allows `b` and `x0` to be
    of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)` expects.
    Reshaping to and from a matrix problem is handled internally.

    The Jax backend version of GMRES uses a homemade implementation that, for
    now, is suboptimal for num_krylov_vecs ~ b.size.

    For the same reason as described in eigsh_lancsoz, the function A_mv
    should be Jittable (or already Jitted) and, if at all possible, defined
    only once at the global scope. A new compilation will be triggered each
    time an A_mv with a new function signature is passed in, even if the
    'new' function is identical to the old one (function identity is
    undecidable).

    Args:
      A_mv: A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and
        `v` have the same shape.
      b: The `b` in `A @ x = b`; it should be of the shape `A_mv` operates on.
      A_args: Positional arguments to `A_mv`, supplied to this interface
        as a list. Default: None.
      A_kwargs: In the other backends, keyword arguments to `A_mv`, supplied
        as a dictionary. However, the Jax backend does not support A_mv
        accepting keyword arguments since this causes problems with Jit.
        Therefore, an error is thrown if A_kwargs is specified. Default: None.
      x0: An optional guess solution. Zeros are used by default.
        If `x0` is supplied, its shape and dtype must match those of `b`, or
        an error will be thrown. Default: zeros.
      tol, atol: Solution tolerance to achieve,
        norm(residual) <= max(tol*norm(b), atol).
        Default: tol=1E-05, atol=tol.
      num_krylov_vectors: Size of the Krylov space to build at each restart.
        Expense is cubic in this parameter. Default: 20.
      maxiter: The Krylov space will be repeatedly rebuilt up to this many
        times. Large values of this argument should be used only with caution,
        since especially for nearly symmetric matrices and small
        `num_krylov_vectors` convergence might well freeze at a value
        significantly larger than `tol`. Default: 1.
      M: Inverse of the preconditioner of A; see the docstring for
        `scipy.sparse.linalg.gmres`. This is unsupported in the Jax backend,
        and NotImplementedError will be raised if it is supplied. Default: None.

    Raises:
      ValueError: -if `x0` is supplied but its shape differs from that of `b`.
                  -if num_krylov_vectors <= 0.
                  -if tol or atol was negative.
      NotImplementedError: - If M is supplied.
                           - If A_kwargs is supplied.
      TypeError: -if the dtype of `x0` and `b` are mismatching.

    Returns:
      x: The converged solution. It has the same shape as `b`.
      info: 0 if convergence was achieved, the number of restarts otherwise.
    """
    if M is not None:
      raise NotImplementedError("M is not supported by the Jax backend.")
    if A_kwargs:
      raise NotImplementedError("A_kwargs is not supported by the Jax backend.")
    if A_mv not in _CACHED_MATVECS:
      # The internal gmres works on flattened (1D) vectors, while A_mv expects
      # tensors of b's shape; this wrapper reshapes on the way in and out.
      # NOTE(review): the wrapper closes over this call's `b.shape` and is
      # cached by the identity of `A_mv`; a later call reusing the same `A_mv`
      # with a differently-shaped `b` would reuse the stale shape — confirm
      # callers never do that.
      @libjax.tree_util.Partial
      def matrix_matvec(x, *args):
        x = x.reshape(b.shape)
        result = A_mv(x, *args)
        return result.ravel()
      _CACHED_MATVECS[A_mv] = matrix_matvec
    if "gmres" not in _CACHED_FUNCTIONS:
      # Build and cache the jitted gmres module once per process.
      _CACHED_FUNCTIONS["gmres"] = jitted_functions.gmres_wrapper(libjax)
    gmres_m = _CACHED_FUNCTIONS["gmres"].gmres_m
    x, _, n_iter, converged = gmres_m(_CACHED_MATVECS[A_mv], A_args, b.ravel(),
                                      x0, tol, atol, num_krylov_vectors,
                                      maxiter, self.jax_precision)
    if converged:
      info = 0
    else:
      # scipy convention: a nonzero info is the iteration count at termination.
      info = n_iter
    x = self.reshape(x, b.shape)
    return x, info
[docs] def conj(self, tensor: Tensor) -> Tensor:
return jnp.conj(tensor)
def eigh(self, matrix: Tensor) -> Tuple[Tensor, Tensor]:
return jnp.linalg.eigh(matrix)
[docs] def addition(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 + tensor2
[docs] def subtraction(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 - tensor2
[docs] def multiply(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 * tensor2
[docs] def divide(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 / tensor2
[docs] def index_update(self, tensor: Tensor, mask: Tensor,
assignee: Tensor) -> Tensor:
return libjax.ops.index_update(tensor, mask, assignee)
[docs] def inv(self, matrix: Tensor) -> Tensor:
if len(matrix.shape) > 2:
raise ValueError("input to numpy backend method `inv` has shape {}."
" Only matrices are supported.".format(matrix.shape))
return jnp.linalg.inv(matrix)
[docs] def broadcast_right_multiplication(self, tensor1: Tensor,
tensor2: Tensor) -> Tensor:
if len(tensor2.shape) != 1:
raise ValueError("only order-1 tensors are allowed for `tensor2`,"
" found `tensor2.shape = {}`".format(tensor2.shape))
return tensor1 * tensor2
[docs] def broadcast_left_multiplication(self, tensor1: Tensor,
tensor2: Tensor) -> Tensor:
if len(tensor1.shape) != 1:
raise ValueError("only order-1 tensors are allowed for `tensor1`,"
" found `tensor1.shape = {}`".format(tensor1.shape))
t1_broadcast_shape = self.shape_concat(
[self.shape_tensor(tensor1), [1] * (len(tensor2.shape) - 1)], axis=-1)
return tensor2 * self.reshape(tensor1, t1_broadcast_shape)
def sin(self, tensor: Tensor) -> Tensor:
return jnp.sin(tensor)
def cos(self, tensor: Tensor) -> Tensor:
return jnp.cos(tensor)
[docs] def exp(self, tensor: Tensor) -> Tensor:
return jnp.exp(tensor)
[docs] def log(self, tensor: Tensor) -> Tensor:
return jnp.log(tensor)
def expm(self, matrix: Tensor) -> Tensor:
if len(matrix.shape) != 2:
raise ValueError("input to numpy backend method `expm` has shape {}."
" Only matrices are supported.".format(matrix.shape))
if matrix.shape[0] != matrix.shape[1]:
raise ValueError("input to numpy backend method `expm` only supports"
" N*N matrix, {x}*{y} matrix is given".format(
x=matrix.shape[0], y=matrix.shape[1]))
# pylint: disable=no-member
return jsp.linalg.expm(matrix)
def jit(self, fun: Callable, *args: List, **kwargs: dict) -> Callable:
return libjax.jit(fun, *args, **kwargs)
[docs] def sum(self,
tensor: Tensor,
axis: Optional[Sequence[int]] = None,
keepdims: bool = False) -> Tensor:
return jnp.sum(tensor, axis=axis, keepdims=keepdims)
[docs] def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
if (tensor1.ndim <= 1) or (tensor2.ndim <= 1):
raise ValueError("inputs to `matmul` have to be tensors of order > 1,")
return jnp.matmul(tensor1, tensor2, precision=self.jax_precision)
[docs] def diagonal(self, tensor: Tensor, offset: int = 0, axis1: int = -2,
axis2: int = -1) -> Tensor:
"""Return specified diagonals.
If tensor is 2-D, returns the diagonal of tensor with the given offset,
i.e., the collection of elements of the form a[i, i+offset].
If a has more than two dimensions, then the axes specified by
axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is
returned. The shape of the resulting array can be determined by removing
axis1 and axis2 and appending an index to the right equal to the size of the
resulting diagonals.
This function only extracts diagonals. If you
wish to create diagonal matrices from vectors, use diagflat.
Args:
tensor: A tensor.
offset: Offset of the diagonal from the main diagonal.
axis1, axis2: Axis to be used as the first/second axis of the 2D
sub-arrays from which the diagonals should be taken.
Defaults to second last/last axis.
Returns:
array_of_diagonals: A dim = min(1, tensor.ndim - 2) tensor storing
the batched diagonals.
"""
if axis1 == axis2:
raise ValueError("axis1, axis2 cannot be equal.")
return jnp.diagonal(tensor, offset=offset, axis1=axis1, axis2=axis2)
[docs] def diagflat(self, tensor: Tensor, k: int = 0) -> Tensor:
""" Flattens tensor and creates a new matrix of zeros with its elements
on the k'th diagonal.
Args:
tensor: A tensor.
k : The diagonal upon which to place its elements.
Returns:
tensor: A new tensor with all zeros save the specified diagonal.
"""
return jnp.diag(jnp.ravel(tensor), k=k)
[docs] def trace(self, tensor: Tensor, offset: int = 0, axis1: int = -2,
axis2: int = -1) -> Tensor:
"""Return summed entries along diagonals.
If tensor is 2-D, the sum is over the
diagonal of tensor with the given offset,
i.e., the collection of elements of the form a[i, i+offset].
If a has more than two dimensions, then the axes specified by
axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is
summed.
Args:
tensor: A tensor.
offset: Offset of the diagonal from the main diagonal.
axis1, axis2: Axis to be used as the first/second axis of the 2D
sub-arrays from which the diagonals should be taken.
Defaults to second last/last axis.
Returns:
array_of_diagonals: The batched summed diagonals.
"""
if axis1 == axis2:
raise ValueError("axis1, axis2 cannot be equal.")
return jnp.trace(tensor, offset=offset, axis1=axis1, axis2=axis2)
def abs(self, tensor: Tensor) -> Tensor:
"""
Returns the elementwise absolute value of tensor.
Args:
tensor: An input tensor.
Returns:
tensor: Its elementwise absolute value.
"""
return jnp.abs(tensor)
[docs] def sign(self, tensor: Tensor) -> Tensor:
"""
Returns an elementwise tensor with entries
y[i] = 1, 0, -1 where tensor[i] > 0, == 0, and < 0 respectively.
For complex input the behaviour of this function may depend on the backend.
The Jax backend version returns y[i] = x[i]/sqrt(x[i]^2).
Args:
tensor: The input tensor.
"""
return jnp.sign(tensor)
[docs] def item(self, tensor):
return tensor.item()
[docs] def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor:
"""
Returns the power of tensor a to the value of b.
In the case b is a tensor, then the power is by element
with a as the base and b as the exponent.
In the case b is a scalar, then the power of each value in a
is raised to the exponent of b.
Args:
a: The tensor that contains the base.
b: The tensor that contains the exponent or a single scalar.
"""
return jnp.power(a, b)
[docs] def eps(self, dtype: Type[np.number]) -> float:
"""
Return machine epsilon for given `dtype`
Args:
dtype: A dtype.
Returns:
float: Machine epsilon.
"""
return jnp.finfo(dtype).eps