Source code for tensorcircuit.backends.pytorch_backend

"""
Backend magic inherited from tensornetwork: pytorch backend
"""
# pylint: disable=invalid-name

import logging
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from operator import mul
from functools import reduce, partial

import tensornetwork
from tensornetwork.backends.pytorch import pytorch_backend
from .abstract_backend import ExtendedBackend

dtypestr: str
Tensor = Any
pytree = Any

torchlib: Any

logger = logging.getLogger(__name__)

# TODO(@refraction-ray): lack stateful random methods implementation for now
# TODO(@refraction-ray): lack scatter impl for now
# TODO(@refraction-ray): lack sparse relevant methods for now
# To be added once pytorch backend is ready


class torch_jit_func:
    """
    Delay the tracing of torch jit until the first run time,
    consistent with the tf and jax mechanisms.
    """

    def __init__(self, f: Callable[..., Any]):
        self.compiled = False
        self.f = f

    def __call__(self, *args: Any, **kws: Any) -> Any:
        if self.compiled is False:
            self.f = torchlib.jit.trace(self.f, example_inputs=args)
            self.compiled = True
        return self.f(*args, **kws)
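# Usage sketch (illustrative only; assumes the usual `tc.set_backend` entry point,
# which instantiates the PyTorchBackend defined below and thereby binds `torchlib`
# to the torch module):
#
# import torch
# import tensorcircuit as tc
#
# tc.set_backend("pytorch")
#
# def f(x):
#     return torch.sin(x) + x
#
# jf = torch_jit_func(f)
# jf(torch.ones([3]))   # first call traces f via torch.jit.trace
# jf(torch.ones([3]))   # later calls reuse the traced function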
class torch_optimizer:
    def __init__(self, optimizer: Any) -> None:
        self.optimizer = optimizer
        self.is_init = False

    def update(self, grads: pytree, params: pytree) -> pytree:
        # flatten grad and param
        params, treedef = PyTorchBackend.tree_flatten(None, params)
        grads, _ = PyTorchBackend.tree_flatten(None, grads)
        if self.is_init is False:
            self.optimizer = self.optimizer(params)
            self.is_init = True
        with torchlib.no_grad():
            for g, p in zip(grads, params):
                p.grad = g
            self.optimizer.step()
            self.optimizer.zero_grad()
        # reorg the param
        params = PyTorchBackend.tree_unflatten(None, treedef, params)
        return params
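# Usage sketch (illustrative only; the names `opt`, `params`, and `grads` are made
# up for the example, and the pytorch backend must be active so that `torchlib`
# is bound):
#
# import torch
# import tensorcircuit as tc
# from functools import partial
#
# tc.set_backend("pytorch")
# opt = torch_optimizer(partial(torch.optim.SGD, lr=0.1))
# params = {"w": torch.ones([2])}
# grads = {"w": torch.tensor([0.5, -0.5])}
# params = opt.update(grads, params)   # one SGD step; same pytree structure returned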
def _conj_torch(self: Any, tensor: Tensor) -> Tensor:
    t = torchlib.conj(tensor)
    return t.resolve_conj()  # any side effect?


def _sum_torch(
    self: Any,
    tensor: Tensor,
    axis: Optional[Sequence[int]] = None,
    keepdims: bool = False,
) -> Tensor:
    if axis is None:
        axis = tuple([i for i in range(len(tensor.shape))])
    return torchlib.sum(tensor, dim=axis, keepdim=keepdims)


def _qr_torch(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = -1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the QR decomposition of a tensor.

    The QR decomposition is performed by treating the tensor as a matrix,
    with an effective left (row) index resulting from combining the axes
    `tensor.shape[:pivot_axis]` and an effective right (column) index
    resulting from combining the axes `tensor.shape[pivot_axis:]`.

    :Example:

    If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2,
    then `q` would have shape (2, 3, 6), and `r` would have shape (6, 4, 5).

    The output consists of two tensors `Q, R` such that:

    Q[i1,...,iN, j] * R[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]

    Note that the output ordering matches numpy.linalg.svd rather than tf.svd.

    :param tensor: A tensor to be decomposed.
    :type tensor: Tensor
    :param pivot_axis: Where to split the tensor's axes before flattening into a matrix.
    :type pivot_axis: int, optional
    :param non_negative_diagonal: a bool indicating whether the diagonal of `r`
        is enforced to be non-negative.
    :type non_negative_diagonal: bool, optional
    :returns: Q, the left tensor factor, and R, the right tensor factor.
    :rtype: Tuple[Tensor, Tensor]
    """
    from .pytorch_ops import torchqr

    left_dims = list(tensor.shape[:pivot_axis])
    right_dims = list(tensor.shape[pivot_axis:])

    tensor = torchlib.reshape(
        tensor, [reduce(mul, left_dims), reduce(mul, right_dims)]
    )
    q, r = torchqr.apply(tensor)
    if non_negative_diagonal:
        phases = torchlib.sign(torchlib.linalg.diagonal(r))
        q = q * phases
        r = phases[:, None] * r
    center_dim = q.shape[1]
    q = torchlib.reshape(q, left_dims + [center_dim])
    r = torchlib.reshape(r, [center_dim] + right_dims)
    return q, r


def _rq_torch(
    self: Any,
    tensor: Tensor,
    pivot_axis: int = 1,
    non_negative_diagonal: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the RQ decomposition of a tensor.

    The RQ decomposition is performed by treating the tensor as a matrix,
    with an effective left (row) index resulting from combining the axes
    `tensor.shape[:pivot_axis]` and an effective right (column) index
    resulting from combining the axes `tensor.shape[pivot_axis:]`.

    :Example:

    If `tensor` had a shape (2, 3, 4, 5) and `pivot_axis` was 2,
    then `r` would have shape (2, 3, 6), and `q` would have shape (6, 4, 5).

    The output consists of two tensors `R, Q` such that:

    R[i1,...,iN, j] * Q[j, k1,...,kM] == tensor[i1,...,iN, k1,...,kM]

    Note that the output ordering matches numpy.linalg.svd rather than tf.svd.

    :param tensor: A tensor to be decomposed.
    :type tensor: Tensor
    :param pivot_axis: Where to split the tensor's axes before flattening into a matrix.
    :type pivot_axis: int, optional
    :param non_negative_diagonal: a bool indicating whether the diagonal of `r`
        is enforced to be non-negative.
    :type non_negative_diagonal: bool, optional
    :returns: R, the left tensor factor, and Q, the right tensor factor.
    :rtype: Tuple[Tensor, Tensor]
    """
    from .pytorch_ops import torchqr

    left_dims = list(tensor.shape[:pivot_axis])
    right_dims = list(tensor.shape[pivot_axis:])

    tensor = torchlib.reshape(
        tensor, [reduce(mul, left_dims), reduce(mul, right_dims)]
    )
    q, r = torchqr.apply(tensor.adjoint())
    if non_negative_diagonal:
        phases = torchlib.sign(torchlib.linalg.diagonal(r))
        q = q * phases
        r = phases[:, None] * r
    r, q = r.adjoint(), q.adjoint()  # M = r * q at this point
    center_dim = r.shape[1]
    r = torchlib.reshape(r, left_dims + [center_dim])
    q = torchlib.reshape(q, [center_dim] + right_dims)
    return r, q


tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.sum = _sum_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.conj = _conj_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.qr = _qr_torch
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.rq = _rq_torch
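# Shape sketch for the patched qr (illustrative; `bk` denotes an instance of the
# PyTorchBackend class defined below, whose construction binds `torchlib`):
# a (2, 3, 4, 5) tensor with pivot_axis=2 is flattened to a (6, 20) matrix,
# so q comes back with shape (2, 3, 6) and r with shape (6, 4, 5).
#
# import torch
#
# bk = PyTorchBackend()
# t = torch.randn(2, 3, 4, 5)
# q, r = bk.qr(t, pivot_axis=2)
# assert q.shape == (2, 3, 6) and r.shape == (6, 4, 5)
# t2 = torch.einsum("abi,icd->abcd", q, r)   # reconstructs t up to numerical error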
class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend):  # type: ignore
    """
    See the original backend API at `pytorch backend
    <https://github.com/google/TensorNetwork/blob/master/tensornetwork/backends/pytorch/pytorch_backend.py>`_

    Note that the functionality provided by the pytorch backend is incomplete;
    it currently lacks native efficient jit and vmap support.
    """
    def __init__(self) -> None:
        super(PyTorchBackend, self).__init__()
        global torchlib
        try:
            import torch
        except ImportError:
            raise ImportError(
                "PyTorch not installed, please switch to a different "
                "backend or install PyTorch."
            )
        torchlib = torch
        self.name = "pytorch"
    def eye(
        self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
    ) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        if not M:
            M = N
        r = torchlib.eye(n=N, m=M)
        return self.cast(r, dtype)

    def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        r = torchlib.ones(shape)
        return self.cast(r, dtype)

    def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
        if dtype is None:
            dtype = dtypestr
        r = torchlib.zeros(shape)
        return self.cast(r, dtype)

    def copy(self, a: Tensor) -> Tensor:
        return a.clone()
    def expm(self, a: Tensor) -> Tensor:
        raise NotImplementedError("pytorch backend doesn't support expm")
        # in 2020, torch has no expm, hmmm. but that's ok,
        # it doesn't support complex numbers, which is a more severe issue.
        # see https://github.com/pytorch/pytorch/issues/9983
    def sin(self, a: Tensor) -> Tensor:
        return torchlib.sin(a)

    def cos(self, a: Tensor) -> Tensor:
        return torchlib.cos(a)

    def acos(self, a: Tensor) -> Tensor:
        return torchlib.acos(a)

    def acosh(self, a: Tensor) -> Tensor:
        return torchlib.acosh(a)

    def asin(self, a: Tensor) -> Tensor:
        return torchlib.asin(a)

    def asinh(self, a: Tensor) -> Tensor:
        return torchlib.asinh(a)

    def atan(self, a: Tensor) -> Tensor:
        return torchlib.atan(a)

    def atan2(self, y: Tensor, x: Tensor) -> Tensor:
        return torchlib.atan2(y, x)

    def atanh(self, a: Tensor) -> Tensor:
        return torchlib.atanh(a)

    def cosh(self, a: Tensor) -> Tensor:
        return torchlib.cosh(a)

    def tan(self, a: Tensor) -> Tensor:
        return torchlib.tan(a)

    def tanh(self, a: Tensor) -> Tensor:
        return torchlib.tanh(a)

    def sinh(self, a: Tensor) -> Tensor:
        return torchlib.sinh(a)

    def size(self, a: Tensor) -> Tensor:
        return a.size()

    def eigvalsh(self, a: Tensor) -> Tensor:
        return torchlib.linalg.eigvalsh(a)

    def kron(self, a: Tensor, b: Tensor) -> Tensor:
        return torchlib.kron(a, b)
    def numpy(self, a: Tensor) -> Tensor:
        a = a.cpu()
        if a.is_conj():
            return a.resolve_conj().numpy()
        if a.requires_grad:
            return a.detach().numpy()
        return a.numpy()

    def i(self, dtype: Any = None) -> Tensor:
        if not dtype:
            dtype = getattr(torchlib, dtypestr)
        if isinstance(dtype, str):
            dtype = getattr(torchlib, dtype)
        return torchlib.tensor(1j, dtype=dtype)

    def det(self, a: Tensor) -> Tensor:
        return torchlib.linalg.det(a)

    def real(self, a: Tensor) -> Tensor:
        try:
            a = torchlib.real(a)
        except RuntimeError:
            pass
        return a

    def imag(self, a: Tensor) -> Tensor:
        try:
            a = torchlib.imag(a)
        except RuntimeError:
            pass
        return a

    def dtype(self, a: Tensor) -> str:
        return a.dtype.__str__().split(".")[-1]  # type: ignore
    def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
        return torchlib.stack(a, dim=axis)

    def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
        return torchlib.cat(a, dim=axis)

    def tile(self, a: Tensor, rep: Tensor) -> Tensor:
        return torchlib.tile(a, rep)

    def mean(
        self,
        a: Tensor,
        axis: Optional[Sequence[int]] = None,
        keepdims: bool = False,
    ) -> Tensor:
        if axis is None:
            axis = tuple([i for i in range(len(a.shape))])
        return torchlib.mean(a, dim=axis, keepdim=keepdims)

    def std(
        self, a: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False
    ) -> Tensor:
        if axis is None:
            axis = tuple([i for i in range(len(a.shape))])
        return torchlib.std(a, dim=axis, unbiased=False, keepdim=keepdims)

    def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        if axis is None:
            return torchlib.min(a)
        return torchlib.min(a, dim=axis).values

    def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        if axis is None:
            return torchlib.max(a)
        return torchlib.max(a, dim=axis).values

    def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
        return torchlib.argmax(a, dim=axis)

    def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
        return torchlib.argmin(a, dim=axis)

    def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
        return torchlib.unique(a, return_counts=True)  # type: ignore

    def sigmoid(self, a: Tensor) -> Tensor:
        return torchlib.sigmoid(a)

    def relu(self, a: Tensor) -> Tensor:
        return torchlib.relu(a)
    def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor:
        return torchlib.nn.functional.softmax(a, dim=axis)
    def onehot(self, a: Tensor, num: int) -> Tensor:
        return torchlib.nn.functional.one_hot(a, num)

    def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
        if axis is None:
            a = self.reshape(a, [-1])
            return torchlib.cumsum(a, dim=0)
        else:
            return torchlib.cumsum(a, dim=axis)

    def is_tensor(self, a: Any) -> bool:
        if isinstance(a, torchlib.Tensor):
            return True
        return False

    def cast(self, a: Tensor, dtype: str) -> Tensor:
        if isinstance(dtype, str):
            return a.type(getattr(torchlib, dtype))
        return a.type(dtype)

    def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
        if stop is None:
            return torchlib.arange(start=0, end=start, step=step)
        return torchlib.arange(start=start, end=stop, step=step)

    def mod(self, x: Tensor, y: Tensor) -> Tensor:
        return torchlib.fmod(x, y)

    def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
        return torchlib.bitwise_right_shift(x, y)

    def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
        return torchlib.bitwise_left_shift(x, y)

    def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
        return torchlib.linalg.solve(A, b)

    def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
        if not self.is_tensor(a):
            a = self.convert_to_tensor(a)
        if not self.is_tensor(v):
            v = self.convert_to_tensor(v)
        return torchlib.searchsorted(a, v, side=side)

    def reverse(self, a: Tensor) -> Tensor:
        return torchlib.flip(a, dims=(-1,))
    def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
        # torch's native tree_map does not support multiple pytree args
        # return torchlib.utils._pytree.tree_map(f, *pytrees)
        args = []
        for pytree in pytrees:
            flat_args, spec = self.tree_flatten(pytree)
            args.append(flat_args)
        res = [
            f(*[args[i][k] for i in range(len(pytrees))])
            for k in range(len(flat_args))
        ]
        return self.tree_unflatten(spec, res)
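    # Leafwise mapping sketch (illustrative): each pytree is flattened, f is
    # applied across corresponding leaves, and the result is unflattened with
    # the last pytree's spec.
    #
    # import torch
    #
    # bk = PyTorchBackend()
    # a = {"x": torch.ones([2]), "y": torch.zeros([2])}
    # b = {"x": torch.ones([2]), "y": torch.ones([2])}
    # s = bk.tree_map(lambda u, v: u + v, a, b)
    # # s == {"x": tensor([2., 2.]), "y": tensor([1., 1.])}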
    def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
        return torchlib.utils._pytree.tree_flatten(pytree)  # type: ignore

    def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
        return torchlib.utils._pytree.tree_unflatten(leaves, treedef)

    def from_dlpack(self, a: Any) -> Tensor:
        return torchlib.utils.dlpack.from_dlpack(a)

    def to_dlpack(self, a: Tensor) -> Any:
        return torchlib.utils.dlpack.to_dlpack(a)
    def cond(
        self,
        pred: bool,
        true_fun: Callable[[], Tensor],
        false_fun: Callable[[], Tensor],
    ) -> Tensor:
        if pred:
            return true_fun()
        return false_fun()

    def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Tensor:
        return branches[index.numpy()]()
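    # Eager control-flow sketch (illustrative; no graph tracing is involved,
    # and switch expects a 0-d integer tensor as the index):
    #
    # import torch
    #
    # bk = PyTorchBackend()
    # r = bk.cond(True, lambda: torch.ones([2]), lambda: torch.zeros([2]))
    # idx = torch.tensor(1)
    # r2 = bk.switch(idx, [lambda: torch.zeros([2]), lambda: torch.ones([2])])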
    def device(self, a: Tensor) -> str:
        dev = a.device
        return self._dev2str(dev)

    def device_move(self, a: Tensor, dev: Any) -> Tensor:
        if not isinstance(dev, str):
            dev = self._dev2str(dev)
        if dev.startswith("gpu"):
            dev = "cuda:" + dev.split(":")[-1]
        return a.to(device=dev)
    def _dev2str(self, dev: Any) -> str:
        if dev.type == "cpu":
            return "cpu"
        if dev.type == "cuda":
            return "gpu:" + str(dev.index)
        raise ValueError("PyTorchBackend doesn't support non-GPU/CPU devices")

    def _str2dev(self, str_: str) -> Any:
        if str_ == "cpu":
            return torchlib.device("cpu")
        if str_.startswith("gpu"):
            _id = int(str_.split(":")[-1])
            return torchlib.device("cuda:" + str(_id))
        raise ValueError("PyTorchBackend doesn't support non-GPU/CPU devices")
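    # Device-string convention sketch (illustrative): the backend speaks
    # "cpu"/"gpu:<i>" while torch itself uses "cpu"/"cuda:<i>".
    #
    # import torch
    #
    # bk = PyTorchBackend()
    # a = torch.ones([2])
    # assert bk.device(a) == "cpu"
    # if torch.cuda.is_available():
    #     a_gpu = bk.device_move(a, "gpu:0")   # mapped internally to "cuda:0"
    #     assert bk.device(a_gpu) == "gpu:0"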
    def stop_gradient(self, a: Tensor) -> Tensor:
        return a.detach()

    def grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Callable[..., Any]:
        def wrapper(*args: Any, **kws: Any) -> Any:
            y, gr = self.value_and_grad(f, argnums, has_aux)(*args, **kws)
            if has_aux:
                return gr, y[1:]
            return gr

        return wrapper

        # def wrapper(*args: Any, **kws: Any) -> Any:
        #     x = []
        #     if isinstance(argnums, int):
        #         argnumsl = [argnums]
        #         # if you also call lhs as argnums, something weird may happen
        #         # the reason is that python then take it as local vars
        #     else:
        #         argnumsl = argnums  # type: ignore
        #     for i, arg in enumerate(args):
        #         if i in argnumsl:
        #             x.append(arg.requires_grad_(True))
        #         else:
        #             x.append(arg)
        #     y = f(*x, **kws)
        #     y.backward()
        #     gs = [x[i].grad for i in argnumsl]
        #     if len(gs) == 1:
        #         gs = gs[0]
        #     return gs
        # return wrapper
    def value_and_grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Callable[..., Tuple[Any, Any]]:
        def wrapper(*args: Any, **kws: Any) -> Any:
            gavf = torchlib.func.grad_and_value(f, argnums=argnums, has_aux=has_aux)
            g, v = gavf(*args, **kws)
            return v, g

        return wrapper

        # def ask_require(t: Tensor) -> Any:
        #     t.requires_grad_(True)
        #     return t

        # def get_grad(t: Tensor) -> Tensor:
        #     return t.grad

        # def wrapper(*args: Any, **kws: Any) -> Any:
        #     # x = []
        #     if isinstance(argnums, int):
        #         argnumsl = [argnums]
        #         # if you also call lhs as argnums, something weird may happen
        #         # the reason is that python then take it as local vars
        #     else:
        #         argnumsl = argnums  # type: ignore
        #     args = list(args)
        #     for i, arg in enumerate(args):
        #         if i in argnumsl:
        #             args[i] = self.tree_map(ask_require, arg)
        #     args = tuple(args)
        #     y = f(*args, **kws)
        #     if has_aux:
        #         y[0].backward()
        #     else:
        #         y.backward()
        #     gs = [self.tree_map(get_grad, x[i]) for i in argnumsl]
        #     if len(gs) == 1:
        #         gs = gs[0]
        #     return y, gs
        # return wrapper
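    # Sketch of grad/value_and_grad (illustrative; assumes a torch version that
    # ships `torch.func.grad_and_value`, i.e. torch >= 2.0). value_and_grad is a
    # thin wrapper that swaps the output order to (value, grad); grad reuses it.
    #
    # import torch
    #
    # bk = PyTorchBackend()
    #
    # def f(x):
    #     return torch.sum(x ** 2)
    #
    # x = torch.ones([3])
    # v, g = bk.value_and_grad(f)(x)   # v == 3.0, g == 2 * x
    # g_only = bk.grad(f)(x)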
    def vjp(
        self,
        f: Callable[..., Any],
        inputs: Union[Tensor, Sequence[Tensor]],
        v: Union[Tensor, Sequence[Tensor]],
    ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(v, list):
            v = tuple(v)
        return torchlib.autograd.functional.vjp(f, inputs, v)  # type: ignore
    def jvp(
        self,
        f: Callable[..., Any],
        inputs: Union[Tensor, Sequence[Tensor]],
        v: Union[Tensor, Sequence[Tensor]],
    ) -> Tuple[Union[Tensor, Sequence[Tensor]], Union[Tensor, Sequence[Tensor]]]:
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(v, list):
            v = tuple(v)
        # for both tf and torch
        # behind the scene: https://j-towns.github.io/2017/06/12/A-new-trick.html
        # to be investigated whether the overhead issue remains as in
        # https://github.com/renmengye/tensorflow-forward-ad/issues/2
        return torchlib.autograd.functional.jvp(f, inputs, v)  # type: ignore
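    # Sketch (illustrative): both delegate to torch.autograd.functional and
    # return an (outputs, product) pair.
    #
    # import torch
    #
    # bk = PyTorchBackend()
    #
    # def f(x):
    #     return x ** 2
    #
    # x = torch.ones([2])
    # v = torch.ones([2])
    # y, vjp_val = bk.vjp(f, x, v)   # vjp_val == 2 * x * v
    # y, jvp_val = bk.jvp(f, x, v)   # jvp_val == 2 * x * v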
    def vmap(
        self,
        f: Callable[..., Any],
        vectorized_argnums: Union[int, Sequence[int]] = 0,
    ) -> Any:
        if isinstance(vectorized_argnums, int):
            vectorized_argnums = (vectorized_argnums,)

        def wrapper(*args: Any, **kws: Any) -> Tensor:
            in_axes = tuple([0 if i in vectorized_argnums else None for i in range(len(args))])  # type: ignore
            return torchlib.vmap(f, in_axes, 0)(*args, **kws)

        return wrapper
        # v3
        # logger.warning(
        #     "pytorch backend has no intrinsic vmap like interface"
        #     ", use plain for loop for compatibility"
        # )
        # # the vmap support is very limited, f must return one tensor
        # # nested list of tensor as return is not supported
        # if isinstance(vectorized_argnums, int):
        #     vectorized_argnums = (vectorized_argnums,)

        # def wrapper(*args: Any, **kws: Any) -> Tensor:
        #     results = []
        #     for barg in zip(*[args[i] for i in vectorized_argnums]):  # type: ignore
        #         narg = []
        #         j = 0
        #         for k in range(len(args)):
        #             if k in vectorized_argnums:  # type: ignore
        #                 narg.append(barg[j])
        #                 j += 1
        #             else:
        #                 narg.append(args[k])
        #         results.append(f(*narg, **kws))
        #     return torchlib.stack(results)

        # return wrapper

        # v2
        # def vmapf(*args: Tensor, **kws: Any) -> Tensor:
        #     r = []
        #     for i in range(args[0].shape[0]):
        #         nargs = [arg[i] for arg in args]
        #         r.append(f(*nargs, **kws))
        #     return torchlib.stack(r)

        # return vmapf

        # v1
        # raise NotImplementedError("pytorch backend doesn't support vmap")
        # There seems to be no map like architecture in pytorch for now
        # see https://discuss.pytorch.org/t/fast-way-to-use-map-in-pytorch/70814
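    # Sketch (illustrative; relies on `torch.vmap`): axis 0 of the tensors named
    # in vectorized_argnums is mapped over, the remaining arguments are broadcast.
    #
    # import torch
    #
    # bk = PyTorchBackend()
    #
    # def f(x, w):
    #     return torch.sum(x * w)
    #
    # xs = torch.ones([5, 3])   # mapped over its leading axis
    # w = torch.ones([3])       # shared by every batch element
    # out = bk.vmap(f, vectorized_argnums=0)(xs, w)   # shape (5,)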
    def jit(
        self,
        f: Callable[..., Any],
        static_argnums: Optional[Union[int, Sequence[int]]] = None,
        jit_compile: Optional[bool] = None,
        **kws: Any
    ) -> Any:
        if jit_compile is True:
            # experimental feature reusing the jit_compile flag for tf
            return torch_jit_func(f)
        return f
        # return f
        # do nothing here until I figure out what torch.jit is for and how it works
        # see https://github.com/pytorch/pytorch/issues/36910
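    # Sketch (illustrative): jit is a no-op unless `jit_compile=True`, in which
    # case the function is wrapped in torch_jit_func and traced lazily.
    #
    # import torch
    #
    # bk = PyTorchBackend()
    #
    # def f(x):
    #     return torch.cos(x) ** 2
    #
    # f_plain = bk.jit(f)                      # returned unchanged
    # f_traced = bk.jit(f, jit_compile=True)   # torch_jit_func wrapper
    # f_traced(torch.ones([4]))                # traced on this first call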
    def vectorized_value_and_grad(
        self,
        f: Callable[..., Any],
        argnums: Union[int, Sequence[int]] = 0,
        vectorized_argnums: Union[int, Sequence[int]] = 0,
        has_aux: bool = False,
    ) -> Callable[..., Tuple[Any, Any]]:
        # [WIP], not a consistent impl compared to tf and jax backend,
        # but pytorch backend is not fully supported anyway
        if isinstance(vectorized_argnums, int):
            vectorized_argnums = (vectorized_argnums,)

        def wrapper(
            *args: Any, **kws: Any
        ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
            jf = self.value_and_grad(f, argnums=argnums, has_aux=has_aux)
            jf = self.vmap(jf, vectorized_argnums=vectorized_argnums)
            vs, gs = jf(*args, **kws)

            if isinstance(argnums, int):
                argnums_list = [argnums]
                gs = [gs]
            else:
                argnums_list = argnums  # type: ignore
                gs = list(gs)
            for i, (j, g) in enumerate(zip(argnums_list, gs)):
                if j not in vectorized_argnums:  # type: ignore
                    gs[i] = self.tree_map(partial(torchlib.sum, dim=0), g)
            if isinstance(argnums, int):
                gs = gs[0]
            else:
                gs = tuple(gs)

            return vs, gs

        return wrapper
    vvag = vectorized_value_and_grad

    optimizer = torch_optimizer
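# Sketch of the vvag (vectorized_value_and_grad) pattern (illustrative): inputs
# are batched over argument 0, gradients are taken w.r.t. the shared argument 1
# and summed over the batch axis.
#
# import torch
#
# bk = PyTorchBackend()
#
# def loss(x, w):
#     return torch.sum((x - w) ** 2)
#
# xs = torch.randn(8, 3)   # batched argument
# w = torch.zeros(3)       # shared argument
# vs, gw = bk.vvag(loss, argnums=1, vectorized_argnums=0)(xs, w)
# # vs has shape (8,); gw is the gradient w.r.t. w, summed over the batch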