"""
Interface wraps quantum function as a torch function
"""
from typing import Any, Callable, Dict, Tuple
from functools import partial
from ..cons import backend
from ..utils import is_sequence
from .tensortrans import general_args_to_backend
Tensor = Any
# TODO(@refraction-ray): new paradigm compatible with torch functional trasnformation
[文档]def torch_interface(
fun: Callable[..., Any], jit: bool = False, enable_dlpack: bool = False
) -> Callable[..., Any]:
"""
Wrap a quantum function on different ML backend with a pytorch interface.
:Example:
.. code-block:: python
import torch
tc.set_backend("tensorflow")
def f(params):
c = tc.Circuit(1)
c.rx(0, theta=params[0])
c.ry(0, theta=params[1])
return c.expectation([tc.gates.z(), [0]])
f_torch = tc.interfaces.torch_interface(f, jit=True)
a = torch.ones([2], requires_grad=True)
b = f_torch(a)
c = b ** 2
c.backward()
print(a.grad)
:param fun: The quantum function with tensor in and tensor out
:type fun: Callable[..., Any]
:param jit: whether to jit ``fun``, defaults to False
:type jit: bool, optional
:param enable_dlpack: whether transform tensor backend via dlpack, defaults to False
:type enable_dlpack: bool, optional
:return: The same quantum function but now with torch tensor in and torch tensor out
while AD is also supported
:rtype: Callable[..., Any]
"""
import torch
def vjp_fun(x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
return backend.vjp(fun, x, v)
if jit is True:
fun = backend.jit(fun)
vjp_fun = backend.jit(vjp_fun)
class Fun(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, *x: Any) -> Any: # type: ignore
# ctx.xdtype = [xi.dtype for xi in x]
ctx.xdtype = backend.tree_map(lambda s: s.dtype, x)
# (x, )
if len(ctx.xdtype) == 1:
ctx.xdtype = ctx.xdtype[0]
ctx.device = (backend.tree_flatten(x)[0][0]).device
x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
y = fun(*x)
ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
if len(x) == 1:
x = x[0]
ctx.x = x
y = general_args_to_backend(
y, target_backend="pytorch", enable_dlpack=enable_dlpack
)
y = backend.tree_map(lambda s: s.to(device=ctx.device), y)
return y
@staticmethod
def backward(ctx: Any, *grad_y: Any) -> Any:
if len(grad_y) == 1:
grad_y = grad_y[0]
grad_y = backend.tree_map(lambda s: s.contiguous(), grad_y)
grad_y = general_args_to_backend(
grad_y, dtype=ctx.ydtype, enable_dlpack=enable_dlpack
)
# grad_y = general_args_to_numpy(grad_y)
# grad_y = numpy_args_to_backend(grad_y, dtype=ctx.ydtype) # backend.dtype
_, g = vjp_fun(ctx.x, grad_y)
# a redundency due to current vjp API
r = general_args_to_backend(
g,
dtype=ctx.xdtype,
target_backend="pytorch",
enable_dlpack=enable_dlpack,
)
r = backend.tree_map(lambda s: s.to(device=ctx.device), r)
if not is_sequence(r):
return (r,)
return r
# currently, memory transparent dlpack in these ML framework has broken support on complex dtypes
return Fun.apply
pytorch_interface = torch_interface
[文档]def torch_interface_kws(
f: Callable[..., Any], jit: bool = True, enable_dlpack: bool = False
) -> Callable[..., Any]:
"""
similar to py:meth:`tensorcircuit.interfaces.torch.torch_interface`,
but now the interface support static arguments for function ``f``,
which is not a tensor and can be used with keyword arguments
:Example:
.. code-block:: python
tc.set_backend("tensorflow")
def f(tensor, integer):
r = 0.
for i in range(integer):
r += tensor
return r
fnew = tc.interfaces.torch_interface_kws(f)
print(fnew(torch.ones([2]), integer=3))
print(fnew(torch.ones([2]), integer=4))
:param f: _description_
:type f: Callable[..., Any]
:param jit: _description_, defaults to True
:type jit: bool, optional
:param enable_dlpack: _description_, defaults to False
:type enable_dlpack: bool, optional
:return: _description_
:rtype: Callable[..., Any]
"""
cache_dict: Dict[Tuple[Any, ...], Callable[..., Any]] = {}
def wrapper(*args: Any, **kws: Any) -> Any:
key = tuple([(k, v) for k, v in kws.items()])
if key not in cache_dict:
fnew = torch_interface(
partial(f, **kws), jit=jit, enable_dlpack=enable_dlpack
)
cache_dict[key] = fnew
return cache_dict[key](*args)
return wrapper