TensorCircuit: The Sharp Bits 🔪#

Speed in TensorCircuit does not come entirely for free: you have to be careful, especially in terms of AD and JIT compatibility.

Jit Compatibility#

Non-tensor inputs or tensor inputs with varying shapes#

Inputs must be in tensor form, and the shape of each input tensor must be fixed; otherwise the function is recompiled, which is very time-consuming. Therefore, if an input argument is a non-tensor or a tensor of frequently varying shape, jit is not recommended.

import numpy as np
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    print("compiling")
    return 2*a

f(K.ones([2]))
# compiling
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([2.+0.j, 2.+0.j], dtype=complex64)>

f(K.zeros([2]))
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.+0.j, 0.+0.j], dtype=complex64)>

f(K.ones([3]))
# compiling
# <tf.Tensor: shape=(3,), dtype=complex64, numpy=array([2.+0.j, 2.+0.j, 2.+0.j], dtype=complex64)>
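
The same retracing happens for non-tensor arguments. A minimal sketch, assuming the TensorFlow backend (where K.jit falls back to tf.function): each distinct Python value of a non-tensor argument triggers a fresh trace.

@K.jit
def g(a, n):  # n is a plain Python int, not a tensor
    print("compiling")
    return n * a

g(K.ones([2]), 2)
# compiling

g(K.ones([2]), 3)
# compiling
# every new Python value of n is compiled into a separate graph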

Mixing numpy and ML backend APIs#

To make a function jittable and AD-aware, every operation within the function should go through the ML backend (the tc.backend API, or the backend framework tf or jax called directly). This is because the ML backend has to trace the computational graph to carry out the AD and JIT transformations; numpy operations are only evaluated at the jit compiling (tracing) stage, i.e. on the first run.

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    return np.dot(a, a)

f(K.ones([2]))
# NotImplementedError: Cannot convert a symbolic Tensor (a:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported
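
One possible fix, as a sketch: express the same contraction through the backend API so that the op is traced into the graph (for a 1-D tensor, np.dot(a, a) is just the sum of elementwise squares).

@K.jit
def f(a):
    # K.sum is traced by the ML backend, so both jit and AD work
    return K.sum(a * a)

f(K.ones([2]))
# <tf.Tensor: shape=(), dtype=complex64, numpy=(2+0j)>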

Nevertheless, numpy calls inside a jitted function can be helpful if you are sure the numpy operations do what you expect, i.e. they run only once at trace time and their results are baked into the compiled graph as constants.

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    print("compiling")
    n = a.shape[0]  # static shape, known at trace time
    m = int(np.log(n) / np.log(2))  # number of qubits: a Python int fixed per trace
    return K.reshape(a, [2 for _ in range(m)])

f(K.ones([4]))
# compiling
# <tf.Tensor: shape=(2, 2), dtype=complex64, numpy=
# array([[1.+0.j, 1.+0.j],
#        [1.+0.j, 1.+0.j]], dtype=complex64)>

f(K.zeros([4]))
# <tf.Tensor: shape=(2, 2), dtype=complex64, numpy=
# array([[0.+0.j, 0.+0.j],
#        [0.+0.j, 0.+0.j]], dtype=complex64)>

f(K.zeros([2]))
# compiling
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.+0.j, 0.+0.j], dtype=complex64)>
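
The flip side is that a numpy result is frozen into the compiled graph as a constant. A minimal sketch of this pitfall: stateful numpy calls such as np.random run only once, at trace time.

@K.jit
def f(a):
    r = float(np.random.uniform())  # evaluated once, at trace time
    return a + r

print(f(K.ones([2])))
print(f(K.ones([2])))
# both calls print the same value: the "random" shift r is a compiled-in constant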

list append under if#

In cases where an if condition depends on a tensor value, appending to a Python list inside the branches gives wrong results: the values from both branches are appended to the list. See the example below.

K = tc.set_backend("tensorflow")

@K.jit
def f(a):
    l = []
    one = K.ones([])
    zero = K.zeros([])
    if a > 0:  # tensor-valued condition: both branches get traced
        l.append(one)
    else:
        l.append(zero)
    return l

f(-K.ones([], dtype="float32"))

# [<tf.Tensor: shape=(), dtype=complex64, numpy=(1+0j)>,
# <tf.Tensor: shape=(), dtype=complex64, numpy=0j>]

The above code directly raises a ``ConcretizationTypeError`` exception on the Jax backend, since Jax jit does not support tensor-valued if conditions.
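
One backend-agnostic workaround, as a sketch: avoid the Python-level branch altogether and select the value arithmetically, so that no tensor ever reaches an if.

@K.jit
def f(a):
    one = K.ones([], dtype="float32")
    zero = K.zeros([], dtype="float32")
    # branch-free select: jit-safe on both the TF and Jax backends
    mask = K.cast(a > 0, dtype="float32")
    return mask * one + (1.0 - mask) * zero

f(-K.ones([], dtype="float32"))
# <tf.Tensor: shape=(), dtype=float32, numpy=0.0>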

Similarly, conditional gates have to be applied with care.

K = tc.set_backend("tensorflow")

@K.jit
def f():
    c = tc.Circuit(1)
    c.h(0)
    a = c.cond_measure(0)
    if a > 0.5:
        c.x(0)
    else:
        c.z(0)
    return c.state()

f()
# InaccessibleTensorError: tf.Graph captured an external symbolic tensor.

# The correct implementation is

@K.jit
def f():
    c = tc.Circuit(1)
    c.h(0)
    a = c.cond_measure(0)
    c.conditional_gate(a, [tc.gates.z(), tc.gates.x()], 0)
    return c.state()

f()
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.99999994+0.j, 0.        +0.j], dtype=complex64)>

Tensor variables consistency#

All tensor variables must be compatible/consistent in backend (tf vs. jax vs. ...), dtype (float vs. complex), shape, and device (cpu vs. gpu).

Inspect the backend, dtype, shape, and device using the following code.

for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor backend:", tc.interfaces.which_backend(a))
        print("tensor dtype:", tc.backend.dtype(a))
        print("tensor shape:", tc.backend.shape_tuple(a))
        print("tensor device:", tc.backend.device(a))

If the backend is inconsistent, one can convert a tensor to the target backend via tensorcircuit.interfaces.tensortrans.general_args_to_backend().

for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor backend:", tc.interfaces.which_backend(a))
        b = tc.interfaces.general_args_to_backend(a, target_backend="jax", enable_dlpack=False)
        print("tensor backend:", tc.interfaces.which_backend(b))

If the dtype is inconsistent, one can convert the tensor dtype using tc.backend.cast.

for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
    with tc.runtime_backend(backend):
        a = tc.backend.ones([2, 3])
        print("tensor dtype:", tc.backend.dtype(a))
        b = tc.backend.cast(a, dtype="float64")
        print("tensor dtype:", tc.backend.dtype(b))

Also note the jax pitfall with float64/complex128: 64-bit precision is disabled by default in jax; see the jax gotchas documentation.
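
For reference, a minimal sketch of enabling 64-bit precision in jax before requesting float64/complex128 tensors (jax.config.update is the standard jax switch):

import jax

jax.config.update("jax_enable_x64", True)  # jax defaults to 32-bit otherwise

with tc.runtime_backend("jax"):
    a = tc.backend.cast(tc.backend.ones([2]), dtype="complex128")
    print(tc.backend.dtype(a))
# complex128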

If the shape is not consistent, one can reshape the tensor via tc.backend.reshape.

If the device is not consistent, one can move the tensor between devices via tc.backend.device_move, as sketched below.
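
A minimal sketch of both fixes; the device string "cpu" is an assumption and may differ across installations (e.g. "gpu:0").

with tc.runtime_backend("tensorflow"):
    a = tc.backend.ones([2, 3])
    b = tc.backend.reshape(a, [3, 2])  # fix a shape mismatch
    print(tc.backend.shape_tuple(b))  # (3, 2)
    c = tc.backend.device_move(a, "cpu")  # move the tensor to the desired device
    print(tc.backend.device(c))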

AD Consistency#

The TF and JAX backends manage the differentiation rules for complex-valued functions differently (the two conventions actually differ by a complex conjugate). See the related discussion in the tensorflow issue.

In TensorCircuit, we currently make this AD discrepancy transparent, i.e. when switching backends, the AD behavior and results for complex-valued functions may differ, determined by the native behavior of the corresponding backend framework. All AD-related ops, such as grad or jacrev, can be affected. Therefore, users must be careful when dealing with the AD of complex-valued functions in a backend-agnostic way in TensorCircuit.

For examples of computing the Jacobian in different modes on different backends, refer to the example script jacobian_cal.py. Also see the code below:

bks = ["tensorflow", "jax"]
n = 2
for bk in bks:
    print(bk, "backend")
    with tc.runtime_backend(bk) as K:
        def wfn(params):
            c = tc.Circuit(n)
            for i in range(n):
                c.H(i)
            for i in range(n):
                c.rz(i, theta=params[i])
                c.rx(i, theta=params[i])
            return K.real(c.expectation_ps(z=[0]) + c.expectation_ps(z=[1]))
        print(K.grad(wfn)(K.ones([n], dtype="complex64")))  # default complex input
        print(K.grad(wfn)(K.ones([n], dtype="float32")))  # real input

# tensorflow backend
# tf.Tensor([0.90929717+0.9228758j 0.90929717+0.9228758j], shape=(2,), dtype=complex64)
# tf.Tensor([0.90929717 0.90929717], shape=(2,), dtype=float32)
# jax backend
# [0.90929747-0.9228759j 0.90929747-0.9228759j]
# [0.90929747 0.90929747]