"""
Interface wraps quantum function as a tensorflow function
"""
from typing import Any, Callable, Tuple
from functools import wraps
from ..cons import backend
from ..utils import return_partial
from .tensortrans import general_args_to_backend
Tensor = Any


def tf_wrapper(
    fun: Callable[..., Any], enable_dlpack: bool = False
) -> Callable[..., Any]:
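    """
    Wrap ``fun``, which takes and returns tensors of the currently active backend,
    into a function that takes and returns tensorflow tensors. Only the value
    conversion is handled here; no custom gradient is registered.

    :Example:

    A minimal sketch (``f`` is a hypothetical user function and the jax backend is
    assumed to be active):

    .. code-block:: python

        K = tc.set_backend("jax")


        def f(x):  # hypothetical function defined on the jax backend
            return K.sum(x)


        f_tf = tf_wrapper(f)
        y = f_tf(tf.ones([2]))  # y is a tf.Tensor
    """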
    @wraps(fun)
    def fun_tf(*x: Any) -> Any:
        x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
        y = fun(*x)
        y = general_args_to_backend(
            y, target_backend="tensorflow", enable_dlpack=enable_dlpack
        )
        return y

    return fun_tf


def tf_dtype(dtype: str) -> Any:
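    """
    Map a dtype given as a string (e.g. ``"float32"``) to the corresponding
    ``tf.DType``; inputs that are already tf dtypes are returned unchanged.

    .. code-block:: python

        tf_dtype("complex64")  # tf.complex64
    """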
    import tensorflow as tf

    if isinstance(dtype, str):
        return getattr(tf, dtype)
    return dtype


def tensorflow_interface(
    fun: Callable[..., Any], ydtype: Any, jit: bool = False, enable_dlpack: bool = False
) -> Callable[..., Any]:
"""
Wrap a quantum function on different ML backend with a tensorflow interface.
:Example:
.. code-block:: python
K = tc.set_backend("jax")
def f(params):
c = tc.Circuit(1)
c.rx(0, theta=params[0])
c.ry(0, theta=params[1])
return K.real(c.expectation([tc.gates.z(), [0]]))
f = tc.interfaces.tf_interface(f, ydtype=tf.float32, jit=True)
tfb = tc.get_backend("tensorflow")
grads = tfb.jit(tfb.grad(f))(tfb.ones([2]))
:param fun: The quantum function with tensor in and tensor out
:type fun: Callable[..., Any]
:param ydtype: output tf dtype or in str
:type ydtype: Any
:param jit: whether to jit ``fun``, defaults to False
:type jit: bool, optional
:param enable_dlpack: whether transform tensor backend via dlpack, defaults to False
:type enable_dlpack: bool, optional
:return: The same quantum function but now with torch tensor in and torch tensor out
while AD is also supported
:rtype: Callable[..., Any]
"""
    import tensorflow as tf

    if jit is True:
        fun = backend.jit(fun)

    fun_tf = tf_wrapper(fun)

    ydtype = backend.tree_map(tf_dtype, ydtype)
    @tf.custom_gradient  # type: ignore
    def fun_wrap(*x: Any) -> Any:
        nx = len(x)
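        # ``vjp_fun`` receives the forward inputs followed by the cotangents in one
        # flat tuple; ``nx`` marks where the inputs end and the cotangents begin.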
        def vjp_fun(*xv: Tensor) -> Tuple[Tensor, Tensor]:
            x = xv[:nx]
            v = xv[nx:]
            if len(x) == 1:
                x = x[0]
            if len(v) == 1:
                v = v[0]
            return backend.vjp(fun, x, v)
        if jit is True:
            vjp_fun = backend.jit(vjp_fun)
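        # ``backend.vjp`` returns (value, vjp); ``return_partial(..., 1)`` keeps only
        # the vjp part, which is what tensorflow's gradient machinery needs.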
        vjp_fun_tf = return_partial(tf_wrapper(vjp_fun), 1)
        xdtype = backend.tree_map(lambda x: x.dtype, x)
        # (x, )
        if len(xdtype) == 1:
            xdtype = xdtype[0]
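        # evaluate the backend function eagerly inside the tf graph via ``tf.py_function``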
        y = tf.py_function(func=fun_tf, inp=x, Tout=ydtype)
        # if len(x) == 1:
        #     x = x[0]

        def grad(*dy: Any, **kws: Any) -> Any:
            # if len(dy) == 1:
            #     dy = dy[0]
            # g = vjp_fun_tf(*(x+dy))
            g = tf.py_function(func=vjp_fun_tf, inp=x + dy, Tout=xdtype)
            # a redundancy due to current vjp API
            return g

        return y, grad
    return fun_wrap  # type: ignore


tf_interface = tensorflow_interface

# TODO(@refraction-ray): overhead and efficiency to be benchmarked
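# A minimal end-to-end sketch of the interface (hypothetical user code, assuming
# tensorflow is installed and the jax backend is active; ``energy`` is not part of
# this module):
#
#     import tensorflow as tf
#     import tensorcircuit as tc
#
#     K = tc.set_backend("jax")
#
#     def energy(params):  # hypothetical variational circuit
#         c = tc.Circuit(1)
#         c.rx(0, theta=params[0])
#         return K.real(c.expectation([tc.gates.z(), [0]]))
#
#     energy_tf = tc.interfaces.tf_interface(energy, ydtype=tf.float32, jit=True)
#
#     x = tf.ones([1])
#     with tf.GradientTape() as tape:
#         tape.watch(x)
#         loss = energy_tf(x)
#     grads = tape.gradient(loss, x)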