TensorCircuit: The Sharp Bits 🔪#
Although TensorCircuit is very fast, you have to use it with care, especially regarding AD and JIT compatibility.
Jit Compatibility#
Non-tensor inputs or tensor inputs of varying shape#
The input must be in tensor form, and the shape of the input tensor must be fixed; otherwise recompilation is triggered, which is very time-consuming. Therefore, if some input arguments are non-tensors or tensors of varying shapes, and they change frequently, jit is not recommended.
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    print("compiling")
    return 2 * a

f(K.ones([2]))
# compiling
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([2.+0.j, 2.+0.j], dtype=complex64)>

f(K.zeros([2]))  # same shape: the compiled function is reused
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.+0.j, 0.+0.j], dtype=complex64)>

f(K.ones([3]))  # new shape: triggers recompilation
# compiling
# <tf.Tensor: shape=(3,), dtype=complex64, numpy=array([2.+0.j, 2.+0.j, 2.+0.j], dtype=complex64)>
Mixing numpy and ML backend APIs#
To make a function jittable and AD-compatible, every operation inside it should go through the ML backend (the tc.backend API, or direct calls to the backend API such as tf or jax). This is because the ML backend has to build a computational graph to carry out the AD and JIT transformations. Numpy operations, by contrast, are only evaluated at the jit compiling stage (the first run).
import numpy as np
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    return np.dot(a, a)  # numpy cannot operate on a symbolic tensor

f(K.ones([2]))
# NotImplementedError: Cannot convert a symbolic Tensor (a:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported
Calling numpy inside a jitted function can still be helpful, as long as you are sure the behavior of the numpy call is what you expect, i.e. it runs only once at compile time, on static quantities such as shapes.
K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    print("compiling")
    n = a.shape[0]
    # numpy is only evaluated at compile time, where the static shape is known
    m = int(np.log(n) / np.log(2))
    return K.reshape(a, [2 for _ in range(m)])

f(K.ones([4]))
# compiling
# <tf.Tensor: shape=(2, 2), dtype=complex64, numpy=
# array([[1.+0.j, 1.+0.j],
#        [1.+0.j, 1.+0.j]], dtype=complex64)>

f(K.zeros([4]))
# <tf.Tensor: shape=(2, 2), dtype=complex64, numpy=
# array([[0.+0.j, 0.+0.j],
#        [0.+0.j, 0.+0.j]], dtype=complex64)>

f(K.zeros([2]))
# compiling
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.+0.j, 0.+0.j], dtype=complex64)>
list append under if#
Appending to a Python list inside an if whose condition depends on a tensor value leads to wrong results: during tracing, the values from both branches get appended to the list. See the example below.
K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    l = []
    one = K.ones([])
    zero = K.zeros([])
    if a > 0:  # tensor-valued condition: both branches are traced
        l.append(one)
    else:
        l.append(zero)
    return l

f(-K.ones([], dtype="float32"))
# [<tf.Tensor: shape=(), dtype=complex64, numpy=(1+0j)>,
#  <tf.Tensor: shape=(), dtype=complex64, numpy=0j>]
The above code raises a ``ConcretizationTypeError`` exception outright on the Jax backend, since Jax jit does not support if conditions on tensor values.
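Under jax jit, tensor-valued branching has to be expressed through a functional control-flow primitive instead of a Python if. Below is a minimal sketch in plain jax (outside TensorCircuit) using jax.lax.cond:

import jax
import jax.numpy as jnp

@jax.jit
def f(a):
    # both branch functions are traced at compile time; the selection happens at run time
    return jax.lax.cond(a > 0, lambda x: jnp.ones([]), lambda x: jnp.zeros([]), a)

f(-jnp.ones([]))
# Array(0., dtype=float32)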
Similarly, conditional gates have to be applied with care.
K = tc.set_backend("tensorflow")

@K.jit
def f():
    c = tc.Circuit(1)
    c.h(0)
    a = c.cond_measure(0)
    if a > 0.5:  # tensor-valued condition: not allowed under jit
        c.x(0)
    else:
        c.z(0)
    return c.state()

f()
# InaccessibleTensorError: tf.Graph captured an external symbolic tensor.

# The correct implementation is

@K.jit
def f():
    c = tc.Circuit(1)
    c.h(0)
    a = c.cond_measure(0)
    # applies gates.z() when the measurement result is 0 and gates.x() when it is 1
    c.conditional_gate(a, [tc.gates.z(), tc.gates.x()], 0)
    return c.state()

f()
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.99999994+0.j, 0.        +0.j], dtype=complex64)>
Tensor variables consistency#
All tensor variables involved in one computation must have a consistent backend (tf vs. jax vs. ...), dtype (float vs. complex), shape, and device (cpu vs. gpu).
Inspect the backend, dtype, shape, and device with the following code.
for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor backend:", tc.interfaces.which_backend(a))
        print("tensor dtype:", tc.backend.dtype(a))
        print("tensor shape:", tc.backend.shape_tuple(a))
        print("tensor device:", tc.backend.device(a))
If the backend is inconsistent, one can convert the tensor backend via tensorcircuit.interfaces.tensortrans.general_args_to_backend().
for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor backend:", tc.interfaces.which_backend(a))
        b = tc.interfaces.general_args_to_backend(a, target_backend="jax", enable_dlpack=False)
        print("tensor backend:", tc.interfaces.which_backend(b))
If the dtype is inconsistent, one can convert the tensor dtype using tc.backend.cast.
for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor dtype:", tc.backend.dtype(a))
        b = tc.backend.cast(a, dtype="float64")
        print("tensor dtype:", tc.backend.dtype(b))
Also note that jax defaults to 32-bit precision, so float64/complex128 support must be enabled explicitly; see the jax gotcha.
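For instance, a minimal sketch of enabling 64-bit precision on jax; the flag must be set before any jax array is created:

from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.ones([2]).dtype)
# float64 (float32 without the flag)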
If the shape is not consistent, one can convert the shape by tc.backend.reshape.
If the device is not consistent, one can move the tensor between devices by tc.backend.device_move.
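A minimal sketch of the two fixes above; the "cpu" device string is an assumption, as the accepted device names depend on the backend and the available hardware:

with tc.runtime_backend("tensorflow"):
    a = tc.backend.ones([2, 3])
    b = tc.backend.reshape(a, [3, 2])  # fix a shape mismatch
    print(tc.backend.shape_tuple(b))
    # (3, 2)
    c = tc.backend.device_move(a, "cpu")  # "cpu" is an assumed device name
    print(tc.backend.device(c))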
AD Consistency#
The TF and JAX backends manage the differentiation rules for complex-valued functions differently (the results actually differ by a complex conjugate). See the discussion in the tensorflow issue.
In TensorCircuit, we currently leave this AD discrepancy transparent: when switching backends, the AD behavior and results for complex-valued functions can differ, as determined by the native behavior of the corresponding backend framework. All AD-related ops, such as grad or jacrev, can be affected. Therefore, users must be careful when dealing with the AD of complex-valued functions in a backend-agnostic way in TensorCircuit.
For an example script computing the Jacobian in different modes on different backends, see jacobian_cal.py. Also refer to the code below:
bks = ["tensorflow", "jax"]
n = 2
for bk in bks:
    print(bk, "backend")
    with tc.runtime_backend(bk) as K:

        def wfn(params):
            c = tc.Circuit(n)
            for i in range(n):
                c.H(i)
            for i in range(n):
                c.rz(i, theta=params[i])
                c.rx(i, theta=params[i])
            return K.real(c.expectation_ps(z=[0]) + c.expectation_ps(z=[1]))

        print(K.grad(wfn)(K.ones([n], dtype="complex64")))  # default
        print(K.grad(wfn)(K.ones([n], dtype="float32")))

# tensorflow backend
# tf.Tensor([0.90929717+0.9228758j 0.90929717+0.9228758j], shape=(2,), dtype=complex64)
# tf.Tensor([0.90929717 0.90929717], shape=(2,), dtype=float32)
# jax backend
# [0.90929747-0.9228759j 0.90929747-0.9228759j]
# [0.90929747 0.90929747]
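As the outputs show, the complex-valued gradients from the two backends are complex conjugates of each other, while the real-valued gradients agree. A simple backend-agnostic workaround is thus to keep the variational parameters real-valued; alternatively, a minimal sketch (not a library-endorsed recipe) is to conjugate the gradient on one backend to match the other's convention:

with tc.runtime_backend("jax") as K:

    def wfn(params):  # the same circuit as above
        c = tc.Circuit(n)
        for i in range(n):
            c.H(i)
        for i in range(n):
            c.rz(i, theta=params[i])
            c.rx(i, theta=params[i])
        return K.real(c.expectation_ps(z=[0]) + c.expectation_ps(z=[1]))

    g = K.grad(wfn)(K.ones([n], dtype="complex64"))
    print(K.conj(g))
    # [0.90929747+0.9228759j 0.90929747+0.9228759j], matching the tensorflow convention above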