MNIST 分类的量子机器学习#

概述#

本教程的目的不是从机器学习的角度来更好地设计用于 MNIST 分类的量子机器学习方法。相反,我们使用一个简单的参数化电路,并演示 tensorcircuit 的量子机器学习相关的技术组件。此外,这个 jupyter notebook 绝不代表量子机器学习的良好实践。 [WIP note]

设置#

[1]:
from functools import partial
import numpy as np
import tensorflow as tf
import jax
from jax.config import config

config.update("jax_enable_x64", True)
from jax import numpy as jnp
import optax
import tensorcircuit as tc
[2]:
tc.set_backend("tensorflow")
tc.set_dtype("complex128")
[2]:
('complex128', 'float64')

数据处理#

我们利用 MNIST 数据并将它们调整为 3*3 以适应 9 量子位电路。 我们使用的测试平台是二进制分类任务,区分数字 1 和 5。 由于本教程不是关于量子机器学习的良好实践,因此我们将验证集放在一边。 我们只为一个小型演示收集 100 个数据点用于实验。

[3]:
# numpy 数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.0


def filter_pair(x, y, a, b):
    """Restrict a labelled dataset to the two classes `a` and `b`.

    Returns the selected samples together with a boolean label array that
    is True exactly where the original label equals `a`.
    """
    mask = (y == a) | (y == b)
    selected_x = x[mask]
    selected_y = y[mask]
    return selected_x, selected_y == a


x_train, y_train = filter_pair(x_train, y_train, 1, 5)
x_train_small = tf.image.resize(x_train, (3, 3)).numpy()
x_train_bin = np.array(x_train_small > 0.5, dtype=np.float32)
x_train_bin = np.squeeze(x_train_bin)[:100]
[4]:
# tensorflow 数据

x_train_tf = tf.reshape(tf.constant(x_train_bin, dtype=tf.float64), [-1, 9])
y_train_tf = tf.constant(y_train[:100], dtype=tf.float64)

# jax 数据

x_train_jax = jnp.array(x_train_bin, dtype=np.float64).reshape([100, -1])
y_train_jax = jnp.array(y_train[:100], dtype=np.float64).reshape([100])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

使用 vectorized_value_and_grad API#

[5]:
nlayers = 3


def qml_loss(x, y, weights, nlayers):
    """Per-sample binary cross-entropy loss of a 9-qubit variational circuit.

    :param x: length-9 feature vector (binarized pixels), used as rx angles
    :param y: binary label (0.0 or 1.0)
    :param weights: trainable parameters of shape [2 * nlayers, 9]
    :param nlayers: number of entangling + rotation layers (static under jit)
    :return: tuple (loss, ypred) where ypred is the rescaled <Z_4> expectation
    """
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    # data encoding: one rx rotation per qubit
    for i in range(n):
        c.rx(i, theta=x[i])
    # variational block: CNOT ladder followed by rx/ry rotations
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypred = c.expectation([tc.gates.z(), (4,)])
    # map <Z> in [-1, 1] to a probability in [0, 1]; the original applied
    # backend.real twice in a row — a single call suffices
    ypred = (tc.backend.real(ypred) + 1) / 2.0
    return -y * tc.backend.log(ypred) - (1 - y) * tc.backend.log(1 - ypred), ypred
[6]:
def get_qml_vvag():
    """Build a jitted, batch-vectorized value-and-grad function for qml_loss.

    Gradients are taken w.r.t. the weights (argnum 2) while vectorizing over
    the per-sample inputs x and y (argnums 0, 1); nlayers is a static argnum.
    """
    vvag = tc.backend.vectorized_value_and_grad(
        qml_loss,
        argnums=(2,),
        vectorized_argnums=(0, 1),
        has_aux=True,
    )
    return tc.backend.jit(vvag, static_argnums=(3,))


qml_vvag = get_qml_vvag()
qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers * 2, 9], dtype=tf.float64), nlayers)
[6]:
((<tf.Tensor: shape=(100,), dtype=float64, numpy=
  array([0.8433698 , 0.56257199, 0.54653163, 0.56257199, 0.82036163,
         0.56257199, 0.56257199, 0.58030506, 0.82036163, 0.56257199,
         0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.54653163,
         0.54653163, 0.56257199, 0.56257199, 0.58030506, 0.82036163,
         0.54653163, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
         0.56257199, 0.56257199, 0.85182866, 0.56257199, 0.82036163,
         0.82036163, 0.56257199, 0.8433698 , 0.56257199, 0.8433698 ,
         0.56257199, 0.85182866, 0.56257199, 0.82036163, 0.54653163,
         0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.8433698 ,
         0.58030506, 0.56257199, 0.82036163, 0.8433698 , 0.8433698 ,
         0.54653163, 0.56257199, 0.82036163, 0.86501404, 0.56257199,
         0.56257199, 0.8433698 , 0.56257199, 0.85182866, 0.82036163,
         0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.56257199,
         0.56257199, 0.82036163, 0.8433698 , 0.8433698 , 0.82036163,
         0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
         0.54653163, 0.86501404, 0.54653163, 0.54653163, 0.82036163,
         0.56257199, 0.54653163, 0.8433698 , 0.54653163, 0.8433698 ,
         0.56257199, 0.56257199, 0.8433698 , 0.82036163, 0.8433698 ,
         0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
         0.82036163, 0.56257199, 0.58030506, 0.8433698 , 0.56257199])>,
  <tf.Tensor: shape=(100,), dtype=float64, numpy=
  array([0.56974181, 0.56974181, 0.57895436, 0.56974181, 0.55972759,
         0.56974181, 0.56974181, 0.55972759, 0.55972759, 0.56974181,
         0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.57895436,
         0.57895436, 0.56974181, 0.56974181, 0.55972759, 0.55972759,
         0.57895436, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
         0.56974181, 0.56974181, 0.57336595, 0.56974181, 0.55972759,
         0.55972759, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
         0.56974181, 0.57336595, 0.56974181, 0.55972759, 0.57895436,
         0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
         0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,
         0.57895436, 0.56974181, 0.55972759, 0.57895436, 0.56974181,
         0.56974181, 0.56974181, 0.56974181, 0.57336595, 0.55972759,
         0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,
         0.56974181, 0.55972759, 0.56974181, 0.56974181, 0.55972759,
         0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
         0.57895436, 0.57895436, 0.57895436, 0.57895436, 0.55972759,
         0.56974181, 0.57895436, 0.56974181, 0.57895436, 0.56974181,
         0.56974181, 0.56974181, 0.56974181, 0.55972759, 0.56974181,
         0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
         0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181])>),
 [<tf.Tensor: shape=(6, 9), dtype=float64, numpy=
  array([[ 5.79464357e-02,  1.12182823e-01,  8.13605755e-02,
           1.52611620e-01,  1.13641690e+00, -1.41695736e+00,
           7.62883290e-01,  4.44089210e-16,  3.33066907e-16],
         [ 3.57155554e-02,  1.61488509e-01,  7.62331819e-02,
           1.50335863e-01, -1.10363460e-01, -3.23606686e-01,
          -3.14523756e-01,  1.11022302e-16, -2.22044605e-16],
         [-3.04959149e-02,  6.03271869e-02,  2.47760477e-02,
          -7.61859417e-02,  1.72064441e+00,  1.66891120e+00,
           3.33066907e-16,  1.11022302e-16,  6.66133815e-16],
         [-8.26503466e-03, -6.90338030e-02, -1.07589110e-01,
           1.88650816e-01, -2.68228700e+00, -2.41159987e+00,
           0.00000000e+00,  1.33226763e-15,  6.66133815e-16],
         [-2.22044605e-16,  3.33066907e-16,  5.55111512e-16,
          -3.33066907e-16,  4.02608813e-01, -7.77156117e-16,
           3.33066907e-16,  1.11022302e-16,  3.33066907e-16],
         [-6.66133815e-16,  9.99200722e-16,  2.22044605e-16,
           5.55111512e-16, -1.11152817e+00,  1.11022302e-16,
           4.44089210e-16,  1.11022302e-16,  8.88178420e-16]])>])
[7]:
# %timeit qml_vvag(x_train_tf, y_train_tf, tf.ones([nlayers*2, 9], dtype=tf.float64), nlayers)

Jax 后端兼容性#

[8]:
tc.set_backend("jax")
[8]:
<tensorcircuit.backends.jax_backend.JaxBackend at 0x7ffb04a71820>
[9]:
qml_vvag = get_qml_vvag()
qml_vvag(
    x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers
)
[9]:
((DeviceArray([0.8433698 , 0.56257199, 0.54653163, 0.56257199, 0.82036163,
               0.56257199, 0.56257199, 0.58030506, 0.82036163, 0.56257199,
               0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.54653163,
               0.54653163, 0.56257199, 0.56257199, 0.58030506, 0.82036163,
               0.54653163, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
               0.56257199, 0.56257199, 0.85182866, 0.56257199, 0.82036163,
               0.82036163, 0.56257199, 0.8433698 , 0.56257199, 0.8433698 ,
               0.56257199, 0.85182866, 0.56257199, 0.82036163, 0.54653163,
               0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.8433698 ,
               0.58030506, 0.56257199, 0.82036163, 0.8433698 , 0.8433698 ,
               0.54653163, 0.56257199, 0.82036163, 0.86501404, 0.56257199,
               0.56257199, 0.8433698 , 0.56257199, 0.85182866, 0.82036163,
               0.82036163, 0.56257199, 0.82036163, 0.56257199, 0.56257199,
               0.56257199, 0.82036163, 0.8433698 , 0.8433698 , 0.82036163,
               0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
               0.54653163, 0.86501404, 0.54653163, 0.54653163, 0.82036163,
               0.56257199, 0.54653163, 0.8433698 , 0.54653163, 0.8433698 ,
               0.56257199, 0.56257199, 0.8433698 , 0.82036163, 0.8433698 ,
               0.56257199, 0.56257199, 0.56257199, 0.56257199, 0.56257199,
               0.82036163, 0.56257199, 0.58030506, 0.8433698 , 0.56257199],            dtype=float64),
  DeviceArray([0.56974181, 0.56974181, 0.57895436, 0.56974181, 0.55972759,
               0.56974181, 0.56974181, 0.55972759, 0.55972759, 0.56974181,
               0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.57895436,
               0.57895436, 0.56974181, 0.56974181, 0.55972759, 0.55972759,
               0.57895436, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
               0.56974181, 0.56974181, 0.57336595, 0.56974181, 0.55972759,
               0.55972759, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
               0.56974181, 0.57336595, 0.56974181, 0.55972759, 0.57895436,
               0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
               0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,
               0.57895436, 0.56974181, 0.55972759, 0.57895436, 0.56974181,
               0.56974181, 0.56974181, 0.56974181, 0.57336595, 0.55972759,
               0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181,
               0.56974181, 0.55972759, 0.56974181, 0.56974181, 0.55972759,
               0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
               0.57895436, 0.57895436, 0.57895436, 0.57895436, 0.55972759,
               0.56974181, 0.57895436, 0.56974181, 0.57895436, 0.56974181,
               0.56974181, 0.56974181, 0.56974181, 0.55972759, 0.56974181,
               0.56974181, 0.56974181, 0.56974181, 0.56974181, 0.56974181,
               0.55972759, 0.56974181, 0.55972759, 0.56974181, 0.56974181],            dtype=float64)),
 (DeviceArray([[ 5.79464357e-02,  1.12182823e-01,  8.13605755e-02,
                 1.52611620e-01,  1.13641690e+00, -1.41695736e+00,
                 7.62883290e-01,  1.01307851e-15, -1.03389519e-15],
               [ 3.57155554e-02,  1.61488509e-01,  7.62331819e-02,
                 1.50335863e-01, -1.10363460e-01, -3.23606686e-01,
                -3.14523756e-01, -4.09394740e-16, -1.65145675e-15],
               [-3.04959149e-02,  6.03271869e-02,  2.47760477e-02,
                -7.61859417e-02,  1.72064441e+00,  1.66891120e+00,
                 6.24500451e-16,  9.71445147e-16, -8.39606162e-16],
               [-8.26503466e-03, -6.90338030e-02, -1.07589110e-01,
                 1.88650816e-01, -2.68228700e+00, -2.41159987e+00,
                -6.66133815e-16,  5.27355937e-16, -7.56339436e-16],
               [-1.08246745e-15,  1.65839564e-15, -2.77555756e-16,
                 1.38777878e-17,  4.02608813e-01, -7.14706072e-16,
                -4.16333634e-17, -9.02056208e-16, -9.57567359e-16],
               [-1.08940634e-15,  3.26128013e-16, -5.34294831e-16,
                 6.93889390e-18, -1.11152817e+00,  3.19189120e-16,
                -1.68615122e-15,  1.24900090e-16,  1.79717352e-15]],            dtype=float64),))
[10]:
# %timeit qml_vvag(x_train_jax, y_train_jax, jnp.ones([nlayers * 2, 9], dtype=np.float64), nlayers)

使用 tf.data 训练模型#

[11]:
# switch back to the tensorflow backend
tc.set_backend("tensorflow")
# NOTE: get_qml_vvag already wraps the function with tc.backend.jit using
# static_argnums=(3,), so the original extra tc.backend.jit call here was
# redundant and has been removed
qml_vvag = get_qml_vvag()
[12]:
mnist_data = (
    tf.data.Dataset.from_tensor_slices((x_train_tf, y_train_tf))
    .repeat(200)
    .shuffle(100)
    .batch(32)
)
[13]:
# Adam optimizer over the variational weights
opt = tf.keras.optimizers.Adam(1e-2)
w = tf.Variable(
    initial_value=tf.random.normal(shape=(2 * nlayers, 9), stddev=0.5, dtype=tf.float64)
)
for i, (xs, ys) in zip(range(2000), mnist_data):
    (losses, ypreds), grad = qml_vvag(xs, ys, w, nlayers)
    # BUG FIX: the original nested apply_gradients inside the logging branch,
    # so the weights were updated only once every 20 batches; the update must
    # happen on every batch
    opt.apply_gradients([(grad[0], w)])
    if i % 20 == 0:
        print(tf.reduce_mean(losses))
tf.Tensor(0.689301607482696, shape=(), dtype=float64)
tf.Tensor(0.6825438352666904, shape=(), dtype=float64)
tf.Tensor(0.6815497367036047, shape=(), dtype=float64)
tf.Tensor(0.6632433448327015, shape=(), dtype=float64)
tf.Tensor(0.6641348270253142, shape=(), dtype=float64)
tf.Tensor(0.6779914200102861, shape=(), dtype=float64)
tf.Tensor(0.6550256969249619, shape=(), dtype=float64)
tf.Tensor(0.6801325087248677, shape=(), dtype=float64)
tf.Tensor(0.6190616725052769, shape=(), dtype=float64)
tf.Tensor(0.6711760566099414, shape=(), dtype=float64)
tf.Tensor(0.6965496746836946, shape=(), dtype=float64)
tf.Tensor(0.6443036572691725, shape=(), dtype=float64)
tf.Tensor(0.6060956714527996, shape=(), dtype=float64)
tf.Tensor(0.6728839286340991, shape=(), dtype=float64)
tf.Tensor(0.6584085272471567, shape=(), dtype=float64)
tf.Tensor(0.6600981577311038, shape=(), dtype=float64)
tf.Tensor(0.6581071758186605, shape=(), dtype=float64)
tf.Tensor(0.6609348320181809, shape=(), dtype=float64)
tf.Tensor(0.5919640703180435, shape=(), dtype=float64)
tf.Tensor(0.6362392080775805, shape=(), dtype=float64)
tf.Tensor(0.6844038809425064, shape=(), dtype=float64)
tf.Tensor(0.6924617230085226, shape=(), dtype=float64)
tf.Tensor(0.6594653043250199, shape=(), dtype=float64)
tf.Tensor(0.7076707818117074, shape=(), dtype=float64)
tf.Tensor(0.6730725215608222, shape=(), dtype=float64)
tf.Tensor(0.6565711271336594, shape=(), dtype=float64)
tf.Tensor(0.6665226844123278, shape=(), dtype=float64)
tf.Tensor(0.6368469891760338, shape=(), dtype=float64)
tf.Tensor(0.6499572506552256, shape=(), dtype=float64)
tf.Tensor(0.6110576844713855, shape=(), dtype=float64)
tf.Tensor(0.6312147945757532, shape=(), dtype=float64)
tf.Tensor(0.6013772883771527, shape=(), dtype=float64)

使用 tf.keras API#

[14]:
from tensorcircuit import keras


def qml_y(x, weights, nlayers):
    """Forward pass of the 9-qubit variational circuit for the keras layer.

    :param x: length-9 feature vector, used as rx encoding angles
    :param weights: trainable parameters of shape [2 * nlayers, 9]
    :param nlayers: number of entangling + rotation layers
    :return: scalar prediction in [0, 1] from the <Z_4> expectation
    """
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    # data encoding layer
    for i in range(n):
        c.rx(i, theta=x[i])
    # variational block: CNOT ladder followed by rx/ry rotations
    for j in range(nlayers):
        for i in range(n - 1):
            c.cnot(i, i + 1)
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypred = c.expectation([tc.gates.z(), (4,)])
    # map <Z> in [-1, 1] to [0, 1]; the original called backend.real twice
    # in a row — a single call suffices
    ypred = (tc.backend.real(ypred) + 1) / 2.0
    return ypred


ql = keras.QuantumLayer(partial(qml_y, nlayers=nlayers), [(2 * nlayers, 9)])
[15]:
# 带有 value and grad 范例的 keras 接口


@tf.function
def my_vvag(xs, ys):
    """Return the batch BCE loss and its gradient w.r.t. the layer weights."""
    with tf.GradientTape() as tape:
        predictions = ql(xs)
        loss_value = tf.keras.losses.BinaryCrossentropy()(ys, predictions)
    grads = tape.gradient(loss_value, ql.variables)
    return loss_value, grads


my_vvag(x_train_tf, y_train_tf)
[15]:
(<tf.Tensor: shape=(), dtype=float64, numpy=0.7179324626922607>,
 [<tf.Tensor: shape=(6, 9), dtype=float64, numpy=
  array([[-1.97741333e-02, -3.24903196e-03, -1.19449484e-02,
           1.34411790e-02, -2.29378194e-03,  9.24968875e-04,
           3.41827505e-04,  1.38777878e-17, -6.93889390e-18],
         [-1.85390086e-02,  3.81940052e-03, -3.05341288e-02,
          -1.79981829e-03, -5.77913396e-02, -3.71762005e-03,
          -5.10097165e-03, -1.71303943e-17, -1.73472348e-18],
         [ 5.04193508e-03, -1.77846516e-02,  2.26429668e-02,
          -1.41076421e-02, -3.13874407e-02,  1.37515418e-03,
           2.08166817e-17,  2.42861287e-17, -1.73472348e-18],
         [ 2.67860892e-02,  1.92311176e-02, -2.44580361e-02,
          -5.08346256e-02, -1.15289797e-02, -8.99461139e-03,
           3.46944695e-18, -5.20417043e-18, -6.93889390e-18],
         [ 9.54097912e-18, -3.46944695e-18,  1.04083409e-17,
          -1.73472348e-18, -2.53960212e-03,  1.31188463e-17,
           5.20417043e-18, -1.38777878e-17, -1.04083409e-17],
         [-2.60208521e-18,  5.20417043e-18, -5.20417043e-18,
           5.20417043e-18, -3.82010017e-03,  1.38777878e-17,
           1.73472348e-18, -2.08166817e-17,  6.93889390e-18]])>])
[16]:
# %timeit my_vvag(x_train_tf, y_train_tf)
[17]:
# keras 接口与 keras 训练范例

model = tf.keras.Sequential([ql])

model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(0.01),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)
Epoch 1/100
4/4 [==============================] - 21s 8ms/step - loss: 0.7221 - binary_accuracy: 0.6016
Epoch 2/100
4/4 [==============================] - 0s 7ms/step - loss: 0.7123 - binary_accuracy: 0.6016
Epoch 3/100
4/4 [==============================] - 0s 8ms/step - loss: 0.7039 - binary_accuracy: 0.6562
Epoch 4/100
4/4 [==============================] - 0s 7ms/step - loss: 0.7009 - binary_accuracy: 0.6562
Epoch 5/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6979 - binary_accuracy: 0.6562
Epoch 6/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6957 - binary_accuracy: 0.6016
Epoch 7/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6935 - binary_accuracy: 0.4922
Epoch 8/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6918 - binary_accuracy: 0.6562
Epoch 9/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6910 - binary_accuracy: 0.7109
Epoch 10/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6901 - binary_accuracy: 0.5469
Epoch 11/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6893 - binary_accuracy: 0.6016
Epoch 12/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6883 - binary_accuracy: 0.6562
Epoch 13/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6876 - binary_accuracy: 0.6016
Epoch 14/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6869 - binary_accuracy: 0.5469
Epoch 15/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6865 - binary_accuracy: 0.7109
Epoch 16/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6858 - binary_accuracy: 0.6562
Epoch 17/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6853 - binary_accuracy: 0.5469
Epoch 18/100
4/4 [==============================] - 0s 9ms/step - loss: 0.6847 - binary_accuracy: 0.6016
Epoch 19/100
4/4 [==============================] - 0s 9ms/step - loss: 0.6844 - binary_accuracy: 0.6016
Epoch 20/100
4/4 [==============================] - 0s 9ms/step - loss: 0.6842 - binary_accuracy: 0.5469
Epoch 21/100
4/4 [==============================] - 0s 9ms/step - loss: 0.6841 - binary_accuracy: 0.6016
Epoch 22/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6839 - binary_accuracy: 0.7109
Epoch 23/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6835 - binary_accuracy: 0.6562
Epoch 24/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6829 - binary_accuracy: 0.6016
Epoch 25/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6823 - binary_accuracy: 0.7109
Epoch 26/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6816 - binary_accuracy: 0.6016
Epoch 27/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6811 - binary_accuracy: 0.5469
Epoch 28/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6805 - binary_accuracy: 0.4922
Epoch 29/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6803 - binary_accuracy: 0.6562
Epoch 30/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6799 - binary_accuracy: 0.5469
Epoch 31/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6795 - binary_accuracy: 0.6016
Epoch 32/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6793 - binary_accuracy: 0.5469
Epoch 33/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6789 - binary_accuracy: 0.6016
Epoch 34/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6785 - binary_accuracy: 0.5469
Epoch 35/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6781 - binary_accuracy: 0.6562
Epoch 36/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6775 - binary_accuracy: 0.5469
Epoch 37/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6762 - binary_accuracy: 0.6016
Epoch 38/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6752 - binary_accuracy: 0.6562
Epoch 39/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6736 - binary_accuracy: 0.6016
Epoch 40/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6714 - binary_accuracy: 0.6562
Epoch 41/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6690 - binary_accuracy: 0.6562
Epoch 42/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6658 - binary_accuracy: 0.6016
Epoch 43/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6637 - binary_accuracy: 0.6016
Epoch 44/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6617 - binary_accuracy: 0.6562
Epoch 45/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6596 - binary_accuracy: 0.6016
Epoch 46/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6586 - binary_accuracy: 0.6016
Epoch 47/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6571 - binary_accuracy: 0.6016
Epoch 48/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6561 - binary_accuracy: 0.6562
Epoch 49/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6549 - binary_accuracy: 0.6562
Epoch 50/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6536 - binary_accuracy: 0.6562
Epoch 51/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6536 - binary_accuracy: 0.6562
Epoch 52/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6519 - binary_accuracy: 0.6016
Epoch 53/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6516 - binary_accuracy: 0.7109
Epoch 54/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6504 - binary_accuracy: 0.6016
Epoch 55/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6500 - binary_accuracy: 0.6016
Epoch 56/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6486 - binary_accuracy: 0.5469
Epoch 57/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6468 - binary_accuracy: 0.6016
Epoch 58/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6466 - binary_accuracy: 0.7109
Epoch 59/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6456 - binary_accuracy: 0.6562
Epoch 60/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6446 - binary_accuracy: 0.7109
Epoch 61/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6435 - binary_accuracy: 0.6016
Epoch 62/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6429 - binary_accuracy: 0.7109
Epoch 63/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6417 - binary_accuracy: 0.7109
Epoch 64/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6432 - binary_accuracy: 0.5469
Epoch 65/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6439 - binary_accuracy: 0.7109
Epoch 66/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6430 - binary_accuracy: 0.6016
Epoch 67/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6415 - binary_accuracy: 0.5469
Epoch 68/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6391 - binary_accuracy: 0.6562
Epoch 69/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6375 - binary_accuracy: 0.6016
Epoch 70/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6372 - binary_accuracy: 0.6016
Epoch 71/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6369 - binary_accuracy: 0.5469
Epoch 72/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6367 - binary_accuracy: 0.6016
Epoch 73/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6380 - binary_accuracy: 0.7109
Epoch 74/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6377 - binary_accuracy: 0.6562
Epoch 75/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6365 - binary_accuracy: 0.6562
Epoch 76/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6350 - binary_accuracy: 0.6562
Epoch 77/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6331 - binary_accuracy: 0.6016
Epoch 78/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6331 - binary_accuracy: 0.6562
Epoch 79/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6337 - binary_accuracy: 0.5469
Epoch 80/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6319 - binary_accuracy: 0.6562
Epoch 81/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6317 - binary_accuracy: 0.7109
Epoch 82/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6312 - binary_accuracy: 0.6562
Epoch 83/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6307 - binary_accuracy: 0.6562
Epoch 84/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6326 - binary_accuracy: 0.6016
Epoch 85/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6307 - binary_accuracy: 0.6016
Epoch 86/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6299 - binary_accuracy: 0.6016
Epoch 87/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6288 - binary_accuracy: 0.6016
Epoch 88/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6288 - binary_accuracy: 0.7109
Epoch 89/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6289 - binary_accuracy: 0.6562
Epoch 90/100
4/4 [==============================] - 0s 8ms/step - loss: 0.6273 - binary_accuracy: 0.6562
Epoch 91/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6275 - binary_accuracy: 0.5469
Epoch 92/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6269 - binary_accuracy: 0.7109
Epoch 93/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6269 - binary_accuracy: 0.6016
Epoch 94/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6263 - binary_accuracy: 0.6016
Epoch 95/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6258 - binary_accuracy: 0.6016
Epoch 96/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6256 - binary_accuracy: 0.6562
Epoch 97/100
4/4 [==============================] - 0s 7ms/step - loss: 0.6250 - binary_accuracy: 0.6562
Epoch 98/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6246 - binary_accuracy: 0.7109
Epoch 99/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6240 - binary_accuracy: 0.6562
Epoch 100/100
4/4 [==============================] - 0s 6ms/step - loss: 0.6251 - binary_accuracy: 0.5469
[17]:
<keras.callbacks.History at 0x7ffaf92858b0>

Keras 中的量子经典混合模型#

[18]:
def qml_ys(x, weights, nlayers):
    """Parameterized quantum circuit producing 9 per-qubit expectation features.

    Encodes the 9 input pixels as RX rotation angles, applies ``nlayers``
    blocks of a CNOT ladder followed by trainable RX/RY rotations, and
    returns the Pauli-Z expectation of each qubit mapped from [-1, 1]
    into [0, 1] so it can feed a classical layer.

    :param x: length-9 input features (binary pixel values) — assumed 1-D; TODO confirm
    :param weights: trainable angles of shape [2 * nlayers, 9]
    :param nlayers: number of entangle+rotate blocks
    :return: backend tensor of 9 values in [0, 1]
    """
    n = 9
    weights = tc.backend.cast(weights, "complex128")
    x = tc.backend.cast(x, "complex128")
    c = tc.Circuit(n)
    # Angle-encode the input pixels.
    for i in range(n):
        c.rx(i, theta=x[i])
    for j in range(nlayers):
        # Entangling ladder over neighboring qubits.
        for i in range(n - 1):
            c.cnot(i, i + 1)
        # Two rows of weights per layer: RX then RY on every qubit.
        for i in range(n):
            c.rx(i, theta=weights[2 * j, i])
            c.ry(i, theta=weights[2 * j + 1, i])
    ypreds = []
    for i in range(n):
        ypred = c.expectation([tc.gates.z(), (i,)])
        # Fix: the original applied tc.backend.real twice; once suffices.
        # Map <Z_i> from [-1, 1] to [0, 1].
        ypred = (tc.backend.real(ypred) + 1) / 2.0
        ypreds.append(ypred)
    return tc.backend.stack(ypreds)
[19]:
# Wrap the quantum circuit as a Keras layer, then stack a classical sigmoid head on top.
ql = tc.keras.QuantumLayer(partial(qml_ys, nlayers=nlayers), [(2 * nlayers, 9)])
model = tf.keras.Sequential()
model.add(ql)
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
[20]:
# Train the hybrid quantum-classical model end to end with binary cross-entropy.
bce_loss = tf.keras.losses.BinaryCrossentropy()
adam_opt = tf.keras.optimizers.Adam(0.01)
acc_metric = tf.keras.metrics.BinaryAccuracy()
model.compile(loss=bce_loss, optimizer=adam_opt, metrics=[acc_metric])
model.fit(x_train_tf, y_train_tf, batch_size=32, epochs=100)
Epoch 1/100
4/4 [==============================] - 24s 14ms/step - loss: 0.9307 - binary_accuracy: 0.3700
Epoch 2/100
4/4 [==============================] - 0s 14ms/step - loss: 0.8286 - binary_accuracy: 0.3700
Epoch 3/100
4/4 [==============================] - 0s 15ms/step - loss: 0.7538 - binary_accuracy: 0.3700
Epoch 4/100
4/4 [==============================] - 0s 14ms/step - loss: 0.7044 - binary_accuracy: 0.3700
Epoch 5/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6796 - binary_accuracy: 0.6300
Epoch 6/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6599 - binary_accuracy: 0.6300
Epoch 7/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6543 - binary_accuracy: 0.6300
Epoch 8/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6559 - binary_accuracy: 0.6300
Epoch 9/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6575 - binary_accuracy: 0.6300
Epoch 10/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6588 - binary_accuracy: 0.6300
Epoch 11/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6587 - binary_accuracy: 0.6300
Epoch 12/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6567 - binary_accuracy: 0.6300
Epoch 13/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6551 - binary_accuracy: 0.6300
Epoch 14/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6540 - binary_accuracy: 0.6300
Epoch 15/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6528 - binary_accuracy: 0.6300
Epoch 16/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6533 - binary_accuracy: 0.6300
Epoch 17/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6540 - binary_accuracy: 0.6300
Epoch 18/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6550 - binary_accuracy: 0.6300
Epoch 19/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6546 - binary_accuracy: 0.6300
Epoch 20/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6538 - binary_accuracy: 0.6300
Epoch 21/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6513 - binary_accuracy: 0.6300
Epoch 22/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6499 - binary_accuracy: 0.6300
Epoch 23/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6497 - binary_accuracy: 0.6300
Epoch 24/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6491 - binary_accuracy: 0.6300
Epoch 25/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6492 - binary_accuracy: 0.6300
Epoch 26/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6493 - binary_accuracy: 0.6300
Epoch 27/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6487 - binary_accuracy: 0.6300
Epoch 28/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6473 - binary_accuracy: 0.6300
Epoch 29/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6470 - binary_accuracy: 0.6300
Epoch 30/100
4/4 [==============================] - 0s 16ms/step - loss: 0.6462 - binary_accuracy: 0.6300
Epoch 31/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6451 - binary_accuracy: 0.6300
Epoch 32/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6439 - binary_accuracy: 0.6300
Epoch 33/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6433 - binary_accuracy: 0.6300
Epoch 34/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6438 - binary_accuracy: 0.6300
Epoch 35/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6413 - binary_accuracy: 0.6300
Epoch 36/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6403 - binary_accuracy: 0.6300
Epoch 37/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6391 - binary_accuracy: 0.6300
Epoch 38/100
4/4 [==============================] - 0s 16ms/step - loss: 0.6388 - binary_accuracy: 0.6300
Epoch 39/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6379 - binary_accuracy: 0.6300
Epoch 40/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6365 - binary_accuracy: 0.6300
Epoch 41/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6352 - binary_accuracy: 0.6300
Epoch 42/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6336 - binary_accuracy: 0.6300
Epoch 43/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6338 - binary_accuracy: 0.6300
Epoch 44/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6358 - binary_accuracy: 0.6300
Epoch 45/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6367 - binary_accuracy: 0.6300
Epoch 46/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6345 - binary_accuracy: 0.6300
Epoch 47/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6303 - binary_accuracy: 0.6300
Epoch 48/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6298 - binary_accuracy: 0.6300
Epoch 49/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6285 - binary_accuracy: 0.6300
Epoch 50/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6280 - binary_accuracy: 0.6300
Epoch 51/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6274 - binary_accuracy: 0.6300
Epoch 52/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6268 - binary_accuracy: 0.6300
Epoch 53/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6262 - binary_accuracy: 0.6300
Epoch 54/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6246 - binary_accuracy: 0.6300
Epoch 55/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6231 - binary_accuracy: 0.6300
Epoch 56/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6228 - binary_accuracy: 0.6300
Epoch 57/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6226 - binary_accuracy: 0.6300
Epoch 58/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6224 - binary_accuracy: 0.6300
Epoch 59/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6228 - binary_accuracy: 0.6900
Epoch 60/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6224 - binary_accuracy: 0.7200
Epoch 61/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6214 - binary_accuracy: 0.7200
Epoch 62/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6183 - binary_accuracy: 0.6300
Epoch 63/100
4/4 [==============================] - 0s 17ms/step - loss: 0.6161 - binary_accuracy: 0.6300
Epoch 64/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6142 - binary_accuracy: 0.6300
Epoch 65/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6131 - binary_accuracy: 0.6300
Epoch 66/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6124 - binary_accuracy: 0.6300
Epoch 67/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6101 - binary_accuracy: 0.6300
Epoch 68/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6117 - binary_accuracy: 0.6600
Epoch 69/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6099 - binary_accuracy: 0.7200
Epoch 70/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6085 - binary_accuracy: 0.7200
Epoch 71/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6070 - binary_accuracy: 0.7200
Epoch 72/100
4/4 [==============================] - 0s 13ms/step - loss: 0.6069 - binary_accuracy: 0.7200
Epoch 73/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6060 - binary_accuracy: 0.7200
Epoch 74/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6040 - binary_accuracy: 0.7200
Epoch 75/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6041 - binary_accuracy: 0.7200
Epoch 76/100
4/4 [==============================] - 0s 15ms/step - loss: 0.6011 - binary_accuracy: 0.6900
Epoch 77/100
4/4 [==============================] - 0s 14ms/step - loss: 0.6005 - binary_accuracy: 0.6300
Epoch 78/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5993 - binary_accuracy: 0.6300
Epoch 79/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5987 - binary_accuracy: 0.6300
Epoch 80/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5982 - binary_accuracy: 0.6300
Epoch 81/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5970 - binary_accuracy: 0.6300
Epoch 82/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5960 - binary_accuracy: 0.6900
Epoch 83/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5937 - binary_accuracy: 0.7200
Epoch 84/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5932 - binary_accuracy: 0.7200
Epoch 85/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5913 - binary_accuracy: 0.7200
Epoch 86/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5904 - binary_accuracy: 0.7200
Epoch 87/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5895 - binary_accuracy: 0.7200
Epoch 88/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5882 - binary_accuracy: 0.7200
Epoch 89/100
4/4 [==============================] - 0s 13ms/step - loss: 0.5873 - binary_accuracy: 0.7200
Epoch 90/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5858 - binary_accuracy: 0.7200
Epoch 91/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5848 - binary_accuracy: 0.7200
Epoch 92/100
4/4 [==============================] - 0s 16ms/step - loss: 0.5839 - binary_accuracy: 0.7200
Epoch 93/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5832 - binary_accuracy: 0.7200
Epoch 94/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5836 - binary_accuracy: 0.7200
Epoch 95/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5848 - binary_accuracy: 0.7200
Epoch 96/100
4/4 [==============================] - 0s 15ms/step - loss: 0.5836 - binary_accuracy: 0.7200
Epoch 97/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5812 - binary_accuracy: 0.7200
Epoch 98/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5795 - binary_accuracy: 0.7200
Epoch 99/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5780 - binary_accuracy: 0.7200
Epoch 100/100
4/4 [==============================] - 0s 14ms/step - loss: 0.5768 - binary_accuracy: 0.7200
[20]:
<keras.callbacks.History at 0x7ffac3142e20>

Jax 中的混合模型#

[21]:
# Switch the global tensorcircuit backend from TensorFlow to JAX for this section.
tc.set_backend("jax")
[21]:
<tensorcircuit.backends.jax_backend.JaxBackend at 0x7ffb04a71820>
[22]:
# Split one seed key into three independent subkeys: one for the quantum
# weights and two for the classical linear head (weight vector and bias).
key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, num=4)
_param_shapes = [
    ("qweights", [nlayers * 2, 9]),
    ("cweights:w", [9]),
    ("cweights:b", [1]),
]
params = {
    name: jax.random.normal(k, shape=shape)
    for (name, shape), k in zip(_param_shapes, subkeys)
}
[23]:
def qml_hybrid_loss(x, y, params, nlayers):
    """Binary cross-entropy loss of the hybrid model on a single example.

    Quantum features from ``qml_ys`` are fed through a classical linear
    layer plus sigmoid, then compared against the binary label ``y``.

    :param x: length-9 input features for one example
    :param y: binary label (0 or 1)
    :param params: dict with "qweights", "cweights:w", "cweights:b"
    :param nlayers: circuit depth passed through to ``qml_ys``
    :return: scalar cross-entropy loss
    """
    lin_w = params["cweights:w"]
    lin_b = params["cweights:b"]
    feats = qml_ys(x, params["qweights"], nlayers)
    feats = tc.backend.reshape(feats, [-1, 1])
    # Classical linear head followed by sigmoid; take the scalar out.
    prob = jax.nn.sigmoid(lin_w @ feats + lin_b)[0]
    return -y * tc.backend.log(prob) - (1 - y) * tc.backend.log(1 - prob)
[24]:
# Vectorize over the batch dimensions of (x, y), differentiate with respect
# to params (argnum 2), and jit-compile with nlayers held static.
_vvag = tc.backend.vvag(qml_hybrid_loss, vectorized_argnums=(0, 1), argnums=2)
qml_hybrid_loss_vag = tc.backend.jit(_vvag, static_argnums=3)
[25]:
qml_hybrid_loss_vag(x_train_jax, y_train_jax, params, nlayers)
[25]:
(DeviceArray([3.73282398, 0.02421603, 0.02899787, 0.02421603, 4.08996787,
              0.03069481, 0.02421603, 0.01688146, 4.08996787, 0.03069481,
              4.08996787, 0.02421603, 4.08996787, 0.02421603, 0.02899787,
              0.03354042, 0.02421603, 0.02421603, 0.01688146, 4.08996787,
              0.03354042, 0.02421603, 0.02421603, 0.03069481, 0.02421603,
              0.02421603, 0.03069481, 3.73798651, 0.02421603, 3.68810189,
              4.08996787, 0.03069481, 3.73282398, 0.03069481, 3.73282398,
              0.02421603, 3.49674264, 0.02421603, 4.08996787, 0.02899787,
              0.02421603, 0.02421603, 0.03069481, 0.03069481, 3.73282398,
              0.02533775, 0.03069481, 3.68810189, 3.73282398, 3.49896983,
              0.02899787, 0.03069481, 4.08996787, 3.41172721, 0.02421603,
              0.02421603, 3.73282398, 0.02421603, 3.73798651, 3.68810189,
              4.08996787, 0.03069481, 4.08996787, 0.02421603, 0.03069481,
              0.02421603, 3.68810189, 3.49896983, 3.49896983, 4.08996787,
              0.02421603, 0.02421603, 0.02421603, 0.02421603, 0.03069481,
              0.02899787, 3.41172721, 0.03354042, 0.02899787, 3.68810189,
              0.02421603, 0.03354042, 3.73282398, 0.02899787, 3.73282398,
              0.03069481, 0.02421603, 3.73282398, 3.68810189, 3.73282398,
              0.02421603, 0.02421603, 0.03069481, 0.03069481, 0.02421603,
              4.08996787, 0.02421603, 0.01688146, 3.73282398, 0.02421603],            dtype=float64),
 {'cweights:b': DeviceArray([34.49476789], dtype=float64),
  'cweights:w': DeviceArray([16.81782277, 15.05718878, 15.02498328, 23.18351696,
               17.01897109, 16.13466029, 16.26046722, 23.54180309,
               12.0721068 ], dtype=float64),
  'qweights': DeviceArray([[-1.16993912e+01, -6.74730815e+00, -2.27227872e+00,
                -1.08703899e+00,  2.56625721e+00,  1.69462223e+00,
                -4.89847061e+00,  1.62487935e+00,  1.02424785e+01],
               [ 3.29984130e+00, -5.90635608e-01,  2.11407610e+00,
                 3.67096431e-02,  3.32526833e+00, -1.06468920e+00,
                -4.12299772e-01, -7.78105081e+00, -3.38506241e+00],
               [-3.59434442e+00,  3.84548015e+00,  8.50409406e-01,
                -2.66504333e+00,  1.47559967e+00,  1.38536529e+00,
                -1.47291602e-01, -7.32213541e+00,  5.17021200e+00],
               [-1.30975045e+00,  1.83003338e+00,  1.51443252e+00,
                 3.15082430e+00, -4.41767236e+00,  6.25968228e+00,
                 5.96980281e+00,  9.67198061e+00, -1.63091455e+01],
               [-2.24757712e+00, -5.66276080e-01, -1.67376432e+00,
                 1.75249049e-01,  2.77917505e-01,  3.84402979e-02,
                 1.03434679e-01, -4.05760762e-02, -3.33671956e-03],
               [ 3.13599600e+00,  3.85470136e+00,  3.17986238e-01,
                 1.72308312e-01,  5.09749793e+00,  2.90706770e-02,
                -5.59919189e-01,  1.96734688e+00, -6.96372626e-01]],            dtype=float64)})
[26]:
# Optimize the hybrid-model parameters with optax Adam.
# NOTE(review): `mnist_data` is presumably a tf.data loader defined in an
# earlier (not shown) cell of the notebook — confirm before reuse.
optimizer = optax.adam(5e-3)
opt_state = optimizer.init(params)
for i, (xs, ys) in zip(range(2000), mnist_data):  # use the tensorflow data loader here
    # Convert TF tensors to numpy so the JAX-backed function can consume them.
    xs = xs.numpy()
    ys = ys.numpy()
    v, grads = qml_hybrid_loss_vag(xs, ys, params, nlayers)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 30 == 0:
        # v holds per-example losses (vectorized); report the batch mean.
        print(jnp.mean(v))
1.2979572281332594
0.8331012068009501
0.6805939758448183
0.5897353928152392
0.6460840124038746
0.6093143713632384
0.6671721223530598
0.5863347320393952
0.5465362554431986
0.5594138744621404
0.5493311423294576
0.5228166702417829
0.6176455570797168
0.5256494465741394
0.5359881696740493
0.5787532611935906
0.49082340457493323
0.4062487079116086
0.5802733401377229
0.4762524476616207
0.5404245247888219