Source code for tensorcircuit.backends.tensorflow_backend

"""
Backend magic inherited from tensornetwork: tensorflow backend
"""
# pylint: disable=invalid-name

import os
import re
from functools import reduce, partial
from operator import mul
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from scipy.sparse import coo_matrix
import tensornetwork
from tensornetwork.backends.tensorflow import tensorflow_backend
from .abstract_backend import ExtendedBackend

dtypestr: str
Tensor = Any
RGenerator = Any  # tf.random.Generator
pytree = Any

tf: Any


[docs]class keras_optimizer:
[docs] def __init__(self, optimizer: Any) -> None: self.optimizer = optimizer self.is_initialized = False
# def _apply_gradients(self, grads: Tensor, params: Tensor) -> None: # self.optimizer.apply_gradients([(grads, params)])
[docs] def update(self, grads: pytree, params: pytree) -> pytree: # if not self.is_initialized: # l, treedef = TensorFlowBackend.tree_flatten(None, params) # # https://github.com/tensorflow/tensorflow/issues/58973 # # still breaks tf2.11 # ol = [deepcopy(self.optimizer) for _ in l] # self.optimizer = TensorFlowBackend.tree_unflatten(None, treedef, ol) # self.is_initialized = True # params = TensorFlowBackend.tree_map(None, self._c2v, params) # don't do the () initialization since cache is in upper level of backend_factory grads_l, _ = TensorFlowBackend.tree_flatten(None, grads) params_l, params_def = TensorFlowBackend.tree_flatten(None, params) if not self.is_initialized: self.params_v = [] self.is_variable = [] for p in params_l: if not isinstance(p, tf.Variable): self.params_v.append(tf.Variable(p)) self.is_variable.append(False) else: self.params_v.append(p) self.is_variable.append(True) self.is_initialized = True else: for i, p in enumerate(params_l): if not isinstance(p, tf.Variable): self.params_v[i] = self.params_v[i].assign(p) else: self.params_v[i] = self.params_v[i].assign(p.value()) self.optimizer.apply_gradients(zip(grads_l, self.params_v)) nparams_l = [] for p, flag in zip(self.params_v, self.is_variable): if flag is True: nparams_l.append(p) else: nparams_l.append(p.value()) params = TensorFlowBackend.tree_unflatten(None, params_def, nparams_l) return params
def _tensordot_tf( self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]] ) -> Tensor: return tf.tensordot(a, b, axes) def _outer_product_tf(self: Any, tensor1: Tensor, tensor2: Tensor) -> Tensor: return tf.tensordot(tensor1, tensor2, 0) def _matmul_tf(self: Any, tensor1: Tensor, tensor2: Tensor) -> Tensor: if (len(tensor1.shape) <= 1) or (len(tensor2.shape) <= 1): raise ValueError("inputs to `matmul` have to be a tensors of order > 1,") return tf.matmul(tensor1, tensor2) def _random_choice_tf( g: RGenerator, a: Union[int, Sequence[int], Tensor], shape: Union[int, Sequence[int]], p: Optional[Union[Sequence[float], Tensor]] = None, ) -> Tensor: # only replace=True support, replace=False is not implemented # for stateless random module, tf has corresponding categorical function similar to choice # however, such utility is not implemented with ``tf.random.Generator`` # part of the code below is inspired by corresponding implementation in jax (Apache 2.0) if isinstance(a, int): assert a > 0 a = tf.range(a) if not (isinstance(a, tf.Tensor) or isinstance(a, tf.Variable)): a = tf.constant(a) assert len(a.shape) == 1 if isinstance(shape, int): shape = (shape,) if p is None: dtype = tf.float32 p = tf.ones_like(a) p = tf.cast(p, dtype=dtype) p /= tf.reduce_sum(p) else: if not (isinstance(p, tf.Tensor) or isinstance(p, tf.Variable)): p = tf.constant(p) dtype = p.dtype shape1 = reduce(mul, shape) p_cuml = tf.cumsum(p) r = p_cuml[-1] * (1 - g.uniform([shape1], dtype=dtype)) ind = tf.searchsorted(p_cuml, r) res = tf.gather(a, ind) return tf.reshape(res, shape) def _qr_tf( self: Any, tensor: Tensor, pivot_axis: int = -1, non_negative_diagonal: bool = False, ) -> Tuple[Tensor, Tensor]: """ Computes the QR decomposition of a tensor. The QR decomposition is performed by treating the tensor as a matrix, with an effective left (row) index resulting from combining the axes `tensor.shape[:pivot_axis]` and an effective right (column) index resulting from combining the axes `tensor.shape[pivot_axis:]`. :Example: If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2, then `q` would have shape (2, 3, 6), and `r` would have shape (6, 4, 5). The output consists of two tensors `Q, R` such that: Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM] Note that the output ordering matches numpy.linalg.svd rather than tf.svd. :param tensor: A tensor to be decomposed. :type tensor: Tensor :param pivot_axis: Where to split the tensor's axes before flattening into a matrix. :type pivot_axis: int, optional :param non_negative_diagonal: a bool indicating whether the tenor is diagonal non-negative matrix. :type non_negative_diagonal: bool, optional :returns: Q, the left tensor factor, and R, the right tensor factor. :rtype: Tuple[Tensor, Tensor] """ from .tf_ops import tfqr left_dims = tf.shape(tensor)[:pivot_axis] right_dims = tf.shape(tensor)[pivot_axis:] tensor = tf.reshape(tensor, [tf.reduce_prod(left_dims), tf.reduce_prod(right_dims)]) q, r = tfqr(tensor) if non_negative_diagonal: phases = tf.math.sign(tf.linalg.diag_part(r)) q = q * phases r = phases[:, None] * r center_dim = tf.shape(q)[1] q = tf.reshape(q, tf.concat([left_dims, [center_dim]], axis=-1)) r = tf.reshape(r, tf.concat([[center_dim], right_dims], axis=-1)) return q, r def _rq_tf( self: Any, tensor: Tensor, pivot_axis: int = 1, non_negative_diagonal: bool = False, ) -> Tuple[Tensor, Tensor]: """ Computes the RQ decomposition of a tensor. The QR decomposition is performed by treating the tensor as a matrix, with an effective left (row) index resulting from combining the axes `tensor.shape[:pivot_axis]` and an effective right (column) index resulting from combining the axes `tensor.shape[pivot_axis:]`. :Example: If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2, then `r` would have shape (2, 3, 6), and `q` would have shape (6, 4, 5). The output consists of two tensors `Q, R` such that: Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM] Note that the output ordering matches numpy.linalg.svd rather than tf.svd. :param tensor: A tensor to be decomposed. :type tensor: Tensor :param pivot_axis: Where to split the tensor's axes before flattening into a matrix. :type pivot_axis: int, optional :param non_negative_diagonal: a bool indicating whether the tenor is diagonal non-negative matrix. :type non_negative_diagonal: bool, optional :returns: Q, the left tensor factor, and R, the right tensor factor. :rtype: Tuple[Tensor, Tensor] """ from .tf_ops import tfqr left_dims = tf.shape(tensor)[:pivot_axis] right_dims = tf.shape(tensor)[pivot_axis:] tensor = tf.reshape(tensor, [tf.reduce_prod(left_dims), tf.reduce_prod(right_dims)]) q, r = tfqr(tf.math.conj(tf.transpose(tensor))) if non_negative_diagonal: phases = tf.math.sign(tf.linalg.diag_part(r)) q = q * phases r = phases[:, None] * r r, q = ( tf.math.conj(tf.transpose(r)), tf.math.conj(tf.transpose(q)), ) # M=r*q at this point center_dim = tf.shape(r)[1] r = tf.reshape(r, tf.concat([left_dims, [center_dim]], axis=-1)) q = tf.reshape(q, tf.concat([[center_dim], right_dims], axis=-1)) return r, q def _svd_tf( self: Any, tensor: Tensor, pivot_axis: int = -1, max_singular_values: Optional[int] = None, max_truncation_error: Optional[float] = None, relative: Optional[bool] = False, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Computes the singular value decomposition (SVD) of a tensor. The SVD is performed by treating the tensor as a matrix, with an effective left (row) index resulting from combining the axes `tensor.shape[:pivot_axis]` and an effective right (column) index resulting from combining the axes `tensor.shape[pivot_axis:]`. For example, if `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2, then `u` would have shape (2, 3, 6), `s` would have shape (6), and `vh` would have shape (6, 4, 5). If `max_singular_values` is set to an integer, the SVD is truncated to keep at most this many singular values. If `max_truncation_error > 0`, as many singular values will be truncated as possible, so that the truncation error (the norm of discarded singular values) is at most `max_truncation_error`. If `relative` is set `True` then `max_truncation_err` is understood relative to the largest singular value. If both `max_singular_values` snd `max_truncation_error` are specified, the number of retained singular values will be `min(max_singular_values, nsv_auto_trunc)`, where `nsv_auto_trunc` is the number of singular values that must be kept to maintain a truncation error smaller than `max_truncation_error`. The output consists of three tensors `u, s, vh` such that: ```python u[i1,...,iN, j] * s[j] * vh[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM] ``` Note that the output ordering matches numpy.linalg.svd rather than tf.svd. Args: tf: The tensorflow module. tensor: A tensor to be decomposed. pivot_axis: Where to split the tensor's axes before flattening into a matrix. max_singular_values: The number of singular values to keep, or `None` to keep them all. max_truncation_error: The maximum allowed truncation error or `None` to not do any truncation. relative: Multiply `max_truncation_err` with the largest singular value. Returns: u: Left tensor factor. s: Vector of ordered singular values from largest to smallest. vh: Right tensor factor. s_rest: Vector of discarded singular values (length zero if no truncation). """ left_dims = tf.shape(tensor)[:pivot_axis] right_dims = tf.shape(tensor)[pivot_axis:] tensor = tf.reshape(tensor, [tf.reduce_prod(left_dims), tf.reduce_prod(right_dims)]) eps = os.environ.get("TC_BACKENDS_TENSORFLOW_BACKEND__SVD_TF_EPS") if eps is not None: eps = 10 ** (-int(eps)) tensor += eps * tf.ones(tensor.shape, dtype=tensor.dtype) # for numerical stability at least in tf+cpu s, u, v = tf.linalg.svd(tensor) if max_singular_values is None: max_singular_values = tf.size(s, out_type=tf.int64) else: max_singular_values = tf.constant(max_singular_values, dtype=tf.int64) if max_truncation_error is not None: # Cumulative norms of singular values in ascending order. trunc_errs = tf.sqrt(tf.cumsum(tf.square(s), reverse=True)) # If relative is true, rescale max_truncation error with the largest # singular value to yield the absolute maximal truncation error. if relative: abs_max_truncation_error = max_truncation_error * s[0] else: abs_max_truncation_error = max_truncation_error # We must keep at least this many singular values to ensure the # truncation error is <= abs_max_truncation_error. num_sing_vals_err = tf.math.count_nonzero( tf.cast(trunc_errs > abs_max_truncation_error, dtype=tf.int32) ) else: num_sing_vals_err = max_singular_values num_sing_vals_keep = tf.minimum(max_singular_values, num_sing_vals_err) # tf.svd() always returns the singular values as a vector of float{32,64}. # since tf.math_ops.real is automatically applied to s. This causes # s to possibly not be the same dtype as the original tensor, which can cause # issues for later contractions. To fix it, we recast to the original dtype. s = tf.cast(s, tensor.dtype) s_rest = s[num_sing_vals_keep:] s = s[:num_sing_vals_keep] u = u[:, :num_sing_vals_keep] v = v[:, :num_sing_vals_keep] vh = tf.linalg.adjoint(v) dim_s = tf.shape(s)[0] # must use tf.shape (not s.shape) to compile u = tf.reshape(u, tf.concat([left_dims, [dim_s]], axis=-1)) vh = tf.reshape(vh, tf.concat([[dim_s], right_dims], axis=-1)) return u, s, vh, s_rest # temporary hot replace until new version of tensorflow is released, # see issue: https://github.com/google/TensorNetwork/issues/940 # avoid buggy tensordot2 in tensornetwork tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.tensordot = ( _tensordot_tf ) tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.outer_product = ( _outer_product_tf ) tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.matmul = ( _matmul_tf ) tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.qr = _qr_tf tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.rq = _rq_tf tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.svd = _svd_tf
[docs]class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend): # type: ignore """ See the original backend API at `tensorflow backend <https://github.com/google/TensorNetwork/blob/master/tensornetwork/backends/tensorflow/tensorflow_backend.py>`_ """
[docs] def __init__(self) -> None: global tf super(TensorFlowBackend, self).__init__() try: import tensorflow except ImportError: raise ImportError( "Tensorflow not installed, please switch to a " "different backend or install Tensorflow." ) tf = tensorflow tf.sparse.SparseTensor.__add__ = tf.sparse.add self.minor = int(tf.__version__.split(".")[1]) self.name = "tensorflow" logger = tf.get_logger() # .setLevel('ERROR') logger.addFilter(lambda s: not re.match(".*You are casting.*", s.getMessage()))
# ignore casting warning by logger
[docs] def eye( self, N: int, dtype: Optional[str] = None, M: Optional[int] = None ) -> Tensor: if dtype is None: dtype = dtypestr r = tf.eye(num_rows=N, num_columns=M) return self.cast(r, dtype)
[docs] def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = dtypestr r = tf.ones(shape=shape) return self.cast(r, dtype)
[docs] def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: if dtype is None: dtype = dtypestr r = tf.zeros(shape=shape) return self.cast(r, dtype)
[docs] def copy(self, a: Tensor) -> Tensor: return tf.identity(a)
[docs] def expm(self, a: Tensor) -> Tensor: return tf.linalg.expm(a)
[docs] def sin(self, a: Tensor) -> Tensor: return tf.math.sin(a)
[docs] def cos(self, a: Tensor) -> Tensor: return tf.math.cos(a)
[docs] def acos(self, a: Tensor) -> Tensor: return tf.math.acos(a)
[docs] def acosh(self, a: Tensor) -> Tensor: return tf.math.acosh(a)
[docs] def asin(self, a: Tensor) -> Tensor: return tf.math.asin(a)
[docs] def asinh(self, a: Tensor) -> Tensor: return tf.math.asinh(a)
[docs] def atan(self, a: Tensor) -> Tensor: return tf.math.atan(a)
[docs] def atan2(self, y: Tensor, x: Tensor) -> Tensor: return tf.math.atan2(y, x)
[docs] def atanh(self, a: Tensor) -> Tensor: return tf.math.atanh(a)
[docs] def cosh(self, a: Tensor) -> Tensor: return tf.math.cosh(a)
[docs] def tan(self, a: Tensor) -> Tensor: return tf.math.tan(a)
[docs] def tanh(self, a: Tensor) -> Tensor: return tf.math.tanh(a)
[docs] def sinh(self, a: Tensor) -> Tensor: return tf.math.sinh(a)
[docs] def size(self, a: Tensor) -> Tensor: return tf.size(a)
[docs] def eigvalsh(self, a: Tensor) -> Tensor: return tf.linalg.eigvalsh(a)
[docs] def dtype(self, a: Tensor) -> str: return a.dtype.__repr__().split(".")[-1] # type: ignore
[docs] def kron(self, a: Tensor, b: Tensor) -> Tensor: # array more than 2d consistency is not guranteed for different backends return tf.reshape( tf.reshape(a, [a.shape[0], 1, a.shape[1], 1]) * tf.reshape(b, [1, b.shape[0], 1, b.shape[1]]), [a.shape[0] * b.shape[0], a.shape[1] * b.shape[1]], )
[docs] def numpy(self, a: Tensor) -> Tensor: if self.is_sparse(a): return coo_matrix( (a.values, (a.indices[:, 0], a.indices[:, 1])), shape=a.get_shape() ) return a.numpy() # only valid in eager mode
[docs] def i(self, dtype: Any = None) -> Tensor: if not dtype: dtype = getattr(tf, dtypestr) if isinstance(dtype, str): dtype = getattr(tf, dtype) return tf.constant(1j, dtype=dtype)
[docs] def det(self, a: Tensor) -> Tensor: return tf.linalg.det(a)
[docs] def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return tf.reduce_min(a, axis=axis)
[docs] def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return tf.reduce_max(a, axis=axis)
[docs] def argmax(self, a: Tensor, axis: int = 0) -> Tensor: return tf.math.argmax(a, axis=axis)
[docs] def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return tf.math.argmin(a, axis=axis)
[docs] def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: r = tf.unique_with_counts(a) order = tf.argsort(r.y) return tf.gather(r.y, order), tf.gather(r.count, order)
[docs] def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.stack(a, axis=axis)
[docs] def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.concat(a, axis=axis)
[docs] def tile(self, a: Tensor, rep: Tensor) -> Tensor: return tf.tile(a, rep)
[docs] def mean( self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False, ) -> Tensor: return tf.math.reduce_mean(a, axis=axis, keepdims=keepdims)
[docs] def std( self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False ) -> Tensor: return tf.math.reduce_std(a, axis=axis, keepdims=keepdims)
[docs] def sigmoid(self, a: Tensor) -> Tensor: return tf.nn.sigmoid(a)
[docs] def relu(self, a: Tensor) -> Tensor: return tf.nn.relu(a)
[docs] def onehot(self, a: Tensor, num: int) -> Tensor: return tf.one_hot(a, num)
[docs] def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor: if axis is None: # make the default behavior consistent r = tf.keras.activations.softmax(tf.reshape(a, [1, -1]), axis=axis) return tf.reshape(r, a.shape) # type: ignore return tf.keras.activations.softmax(a, axis=axis)
[docs] def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: if axis is None: a = tf.reshape(a, [-1]) return tf.cumsum(a) else: return tf.cumsum(a, axis)
[docs] def is_tensor(self, a: Any) -> bool: if isinstance(a, tf.Tensor) or isinstance(a, tf.Variable): return True return False
[docs] def abs(self, a: Tensor) -> Tensor: return tf.math.abs(a)
[docs] def real(self, a: Tensor) -> Tensor: return tf.math.real(a)
[docs] def imag(self, a: Tensor) -> Tensor: return tf.math.imag(a)
[docs] def cast(self, a: Tensor, dtype: str) -> Tensor: if isinstance(dtype, str): return tf.cast(a, dtype=getattr(tf, dtype)) return tf.cast(a, dtype=dtype)
[docs] def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor: if stop is None: return tf.range(start=0, limit=start, delta=step) return tf.range(start=start, limit=stop, delta=step)
[docs] def mod(self, x: Tensor, y: Tensor) -> Tensor: return tf.math.mod(x, y)
[docs] def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return tf.bitwise.right_shift(x, y)
[docs] def left_shift(self, x: Tensor, y: Tensor) -> Tensor: return tf.bitwise.left_shift(x, y)
[docs] def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor: if b.shape[-1] == A.shape[-1]: b = b[..., tf.newaxis] vector = True else: vector = False x = tf.linalg.solve(A, b) if vector: return self.reshape(x, x.shape[:-1]) return x
[docs] def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor: return tf.searchsorted(a, v, side)
[docs] def from_dlpack(self, a: Any) -> Tensor: return tf.experimental.dlpack.from_dlpack(a)
[docs] def to_dlpack(self, a: Tensor) -> Any: return tf.experimental.dlpack.to_dlpack(a)
# note complex tensor support for dlpack is only available for tf>=2.9
[docs] def set_random_state( self, seed: Optional[Union[int, RGenerator]] = None, get_only: bool = False ) -> Any: if seed is None: g = tf.random.Generator.from_non_deterministic_state() elif isinstance(seed, int): g = tf.random.Generator.from_seed(seed) else: g = seed if get_only is False: self.g = g return g
[docs] def stateful_randn( self, g: RGenerator, shape: Union[int, Sequence[int]] = 1, mean: float = 0, stddev: float = 1, dtype: str = "32", ) -> Tensor: if isinstance(dtype, str): dtype = dtype[-2:] if isinstance(shape, int): shape = (shape,) if dtype == "32": dtyper = tf.float32 elif dtype == "64": dtyper = tf.float64 elif not isinstance(dtype, str): dtyper = dtype return g.normal(shape, mean, stddev, dtype=dtyper)
[docs] def stateful_randu( self, g: RGenerator, shape: Union[int, Sequence[int]] = 1, low: float = 0, high: float = 1, dtype: str = "32", ) -> Tensor: if isinstance(dtype, str): dtype = dtype[-2:] if isinstance(shape, int): shape = (shape,) if dtype == "32": dtyper = tf.float32 elif dtype == "64": dtyper = tf.float64 elif not isinstance(dtype, str): dtyper = dtype return g.uniform(shape, minval=low, maxval=high, dtype=dtyper)
[docs] def stateful_randc( self, g: RGenerator, a: Union[int, Sequence[int], Tensor], shape: Union[int, Sequence[int]], p: Optional[Union[Sequence[float], Tensor]] = None, ) -> Tensor: return _random_choice_tf(g, a, shape, p)
[docs] def gather1d(self, operand: Tensor, indices: Tensor) -> Tensor: return tf.gather(operand, indices)
[docs] def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor: return tf.tensor_scatter_nd_update(operand, indices, updates)
[docs] def coo_sparse_matrix( self, indices: Tensor, values: Tensor, shape: Tensor ) -> Tensor: return tf.SparseTensor(indices=indices, values=values, dense_shape=shape)
[docs] def sparse_dense_matmul( self, sp_a: Tensor, b: Tensor, ) -> Tensor: return tf.sparse.sparse_dense_matmul(sp_a, b)
def _densify(self) -> Tensor: @partial(self.jit, jit_compile=True) def densify(sp_a: Tensor) -> Tensor: return tf.sparse.to_dense(sp_a) return densify
[docs] def to_dense(self, sp_a: Tensor) -> Tensor: return self._densify()(sp_a)
# very weirdly, at least for cpu, tf.sparse.to_dense only works within tf.function(jit_compile=True) # and will fail with tf.function or bare function # on the contrary, tf.sparse.sparse_dense_matmul only fails within tf.function(jit_compile=True) # and works well with tf.function or bare function...
[docs] def is_sparse(self, a: Tensor) -> bool: return isinstance(a, tf.SparseTensor)
[docs] def cond( self, pred: bool, true_fun: Callable[[], Tensor], false_fun: Callable[[], Tensor], ) -> Tensor: return tf.cond(pred, true_fun, false_fun)
[docs] def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Tensor: return tf.switch_case(index, branches)
[docs] def scan( self, f: Callable[[Tensor, Tensor], Tensor], xs: Tensor, init: Tensor ) -> Tensor: return tf.scan(f, xs, init)[-1]
[docs] def device(self, a: Tensor) -> str: dev = a.device return self._dev2str(dev)
[docs] def device_move(self, a: Tensor, dev: Any) -> Tensor: if isinstance(dev, str): dev = self._str2dev(dev) with tf.device(dev): a = tf.identity(a) return a
def _dev2str(self, dev: Any) -> str: platform, id_ = dev.split(":")[-2:] if platform == "CPU": return "cpu" if platform == "GPU": return "gpu" + ":" + str(id_) raise ValueError("TensorFlowBackend don't support non-GPU/CPU device") def _str2dev(self, str_: str) -> Any: return str_
[docs] def stop_gradient(self, a: Tensor) -> Tensor: return tf.stop_gradient(a)
[docs] def grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Any]: # experimental attempt # Note: tensorflow grad is gradient while jax grad is derivative, they are different with a conjugate! # And we DONT make them consitent by mannually set conjugate of the returns. def wrapper(*args: Any, **kws: Any) -> Any: with tf.GradientTape() as t: if isinstance(argnums, int): x = args[argnums] else: args = tuple( [ tf.identity(arg) if i in argnums else arg for i, arg in enumerate(args) ] ) # in case wrong grad for f(x, x) x = [args[i] for i in argnums] t.watch(x) y = f(*args, **kws) if has_aux: g = t.gradient(y[0], x) else: g = t.gradient(y, x) if has_aux: return (g, y[1:]) return g return wrapper
[docs] def value_and_grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Tuple[Any, Any]]: def wrapper(*args: Any, **kws: Any) -> Any: with tf.GradientTape() as t: if isinstance(argnums, int): x = args[argnums] else: args = tuple( [ tf.identity(arg) if i in argnums else arg for i, arg in enumerate(args) ] ) x = [args[i] for i in argnums] t.watch(x) y = f(*args, **kws) if has_aux: g = t.gradient(y[0], x) else: g = t.gradient(y, x) return y, g return wrapper
[docs] def jvp( self, f: Callable[..., Any], inputs: Union[Tensor, Sequence[Tensor]], v: Union[Tensor, Sequence[Tensor]], ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]: if not (isinstance(inputs, list) or isinstance(inputs, tuple)): # one input tensor inputs = [inputs] elif isinstance(inputs, list): inputs = [tf.identity(inp) for inp in inputs] else: # inputs, tuple inputs = tuple([tf.identity(inp) for inp in inputs]) if not (isinstance(v, list) or isinstance(v, tuple)): v = [v] elif isinstance(v, list): v = [tf.identity(vi) for vi in v] else: v = tuple([tf.identity(vi) for vi in v]) with tf.autodiff.ForwardAccumulator(inputs, v) as t: y = f(*inputs) g = t.jvp(y) return y, g
[docs] def vjp( self, f: Callable[..., Any], inputs: Union[Tensor, Sequence[Tensor]], v: Union[Tensor, Sequence[Tensor]], ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]: if not (isinstance(inputs, list) or isinstance(inputs, tuple)): # one input tensor one_input = True inputs = [inputs] elif isinstance(inputs, list): inputs = [tf.identity(inp) for inp in inputs] one_input = False else: # inputs tuple inputs = tuple([tf.identity(inp) for inp in inputs]) one_input = False with tf.GradientTape() as t: t.watch(inputs) y = f(*inputs) g = t.gradient(y, inputs, v) g = list(g) for i, gi in enumerate(g): if gi is None: g[i] = tf.zeros_like(inputs[i]) if isinstance(gi, tf.IndexedSlices): # gradient can return sth weird # TODO(@refraction-ray): check whether other AD tf methods have such issues # shape is still unkown, dense_shape attr doesn't work? g[i] = tf.convert_to_tensor(gi) g = tuple(g) if one_input: g = g[0] return y, g
[docs] def jit( self, f: Callable[..., Any], static_argnums: Optional[Union[int, Sequence[int]]] = None, jit_compile: Optional[bool] = None, **kws: Any ) -> Any: # static_argnums not supported in tf case, this is only for a consistent interface # for more on static_argnums in tf.function, see issue: https://github.com/tensorflow/tensorflow/issues/52193 # tf.function works with dict pytree but fails at list pytree, hmm... # no full jittable pytree support in tf # another difference from jax.jit if self.minor < 5: return tf.function(f, experimental_compile=jit_compile) else: return tf.function(f, jit_compile=jit_compile)
# old vmap impl before I know pytrees support in tf # def vmap( # self, f: Callable[..., Any], vectorized_argnums: Union[int, Sequence[int]] = 0 # ) -> Any: # if isinstance(vectorized_argnums, int): # vectorized_argnums = (vectorized_argnums,) # if vectorized_argnums == (0,): # fast shortcut # def wrapper(*args: Any, **kws: Any) -> Tensor: # def pf(x: Tensor) -> Tensor: # return f(x, *args[1:], **kws) # return tf.vectorized_map(pf, args[0]) # else: # @self.jit # def wrapper(*args: Any, **kws: Any) -> Tensor: # shapes = [] # l = len(args) # seps = [0] # batch = args[vectorized_argnums[0]].shape[0] # type: ignore # nargs = [] # for i, arg in enumerate(args): # if i in vectorized_argnums: # type: ignore # shapes.append(arg.shape[1:]) # assert ( # arg.shape[0] == batch # ), "different tensors has different batch dimensions!" # arg = tf.reshape(arg, [batch, -1]) # nargs.append(arg) # seps.append(seps[-1] + arg.shape[-1]) # sargs = tf.concat(nargs, 1) # def sf(sarg: Any) -> Any: # vargs = [] # for i in range(len(shapes)): # arg = sarg[seps[i] : seps[i + 1]] # arg = tf.reshape(arg, shapes[i]) # vargs.append(arg) # vvargs = [] # j = 0 # for i in range(l): # if i in vectorized_argnums: # type: ignore # vvargs.append(vargs[j]) # j += 1 # else: # vvargs.append(args[i]) # return f(*vvargs, **kws) # return tf.vectorized_map(sf, sargs) # return wrapper # def wrapper(f: Callable[..., Any], args: Sequence[Any]) -> Any: # return f(*args) # wrapper = partial(wrapper, f) # def own_vectorized_map(f: Callable[..., Any], *args: Any) -> Any: # return tf.vectorized_map(f, args) # return partial(own_vectorized_map, wrapper)
[docs] def vmap( self, f: Callable[..., Any], vectorized_argnums: Union[int, Sequence[int]] = 0 ) -> Any: if isinstance(vectorized_argnums, int): vectorized_argnums = (vectorized_argnums,) if vectorized_argnums == (0,): # fast shortcut def wrapper(*args: Any, **kws: Any) -> Tensor: def pf(x: Tensor) -> Tensor: return f(x, *args[1:], **kws) return tf.vectorized_map(pf, args[0]) else: # @self.jit # otherwise, vectorized_map claim on retracing def wrapper(*args: Any, **kws: Any) -> Tensor: # @self.jit def sf(sarg: Any) -> Any: vvargs = [] j = 0 for i in range(len(args)): if i in vectorized_argnums: # type: ignore vvargs.append(sarg[j]) j += 1 else: vvargs.append(args[i]) return f(*vvargs, **kws) sarg = [] for i in vectorized_argnums: # type: ignore sarg.append(args[i]) return tf.vectorized_map(sf, sarg) return wrapper
[docs] def vectorized_value_and_grad( self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0, vectorized_argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, ) -> Callable[..., Tuple[Any, Any]]: # note how tf only works in this order, due to the bug reported as: # https://github.com/google/TensorNetwork/issues/940 vf = self.vmap(f, vectorized_argnums=vectorized_argnums) def wrapper( *args: Any, **kws: Any ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: with tf.GradientTape() as tape: if isinstance(argnums, int): x = args[argnums] else: x = [args[i] for i in argnums] tape.watch(x) vs = vf(*args, **kws) if has_aux: grad = tape.gradient(vs[0], x) else: grad = tape.gradient(vs, x) return vs, grad return wrapper
# f = self.vmap(f) # f = self.value_and_grad(f, argnums=argnums) # f = self.jit(f) # return f vvag = vectorized_value_and_grad optimizer = keras_optimizer