可微量子架构搜索#

概述#

本教程演示了如何利用 TensorCircuit 提供的高级计算功能(例如 jit 和 vmap)来高效地模拟可微量子架构搜索(DQAS)算法:一组具有不同结构的量子电路可以被同时编译和模拟。 [WIP note]

设置#

[1]:
import numpy as np
import tensorcircuit as tc
import tensorflow as tf
[2]:
# Select TensorFlow as the array backend; K is the backend handle used by every cell below.
K = tc.set_backend("tensorflow")
# Switch the default precision to complex128. set_dtype hands back the complex dtype and its
# paired real dtype; both are reused below for casts and random initialization.
ctype, rtype = tc.set_dtype("complex128")

问题描述#

任务是找到制备 GHZ 态 \(\vert \text{GHZ}_N\rangle = \frac{1}{\sqrt{2}}\left(\vert 0^N\rangle +\vert 1^N\rangle \right)\) 的电路。对于 \(N=2\) 的演示,我们准备了一个由 rx0、rx1、ry0、ry1、rz0、rz1、cnot01、cnot10 组成的门池,这八个门中有六个是参数化的。

[3]:
def _pauli_rotation(pauli, theta):
    """Return the 2x2 matrix ``cos(theta) * I + 1j * sin(theta) * pauli``.

    For a Pauli matrix P (P @ P = I) this equals exp(1j * theta * P). Note this is
    not the common exp(-1j * theta / 2 * P) convention.
    """
    return K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * pauli


def _on_qubit(op, qubit):
    """Embed the single-qubit matrix ``op`` on ``qubit`` (0 or 1) of a 2-qubit register.

    Qubit 0 is the left Kronecker factor, so basis states are ordered |q0 q1>.
    """
    if qubit == 0:
        return K.kron(op, K.eye(2))
    if qubit == 1:
        return K.kron(K.eye(2), op)
    raise ValueError("qubit must be 0 or 1, got %r" % (qubit,))


def rx0(theta):
    """X rotation exp(i*theta*X) on qubit 0, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._x_matrix, theta), 0)


def rx1(theta):
    """X rotation exp(i*theta*X) on qubit 1, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._x_matrix, theta), 1)


def ry0(theta):
    """Y rotation exp(i*theta*Y) on qubit 0, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._y_matrix, theta), 0)


def ry1(theta):
    """Y rotation exp(i*theta*Y) on qubit 1, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._y_matrix, theta), 1)


def rz0(theta):
    """Z rotation exp(i*theta*Z) on qubit 0, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._z_matrix, theta), 0)


def rz1(theta):
    """Z rotation exp(i*theta*Z) on qubit 1, as a 4x4 two-qubit matrix."""
    return _on_qubit(_pauli_rotation(tc.gates._z_matrix, theta), 1)


def cnot01():
    """TensorCircuit's CNOT matrix cast to ``ctype`` (presumably control 0, target 1 — confirm)."""
    return K.cast(K.convert_to_tensor(tc.gates._cnot_matrix), ctype)


# CNOT with control qubit 1 and target qubit 0 in the |q0 q1> basis
# (|00>, |01>, |10>, |11>): it swaps |01> <-> |11> and leaves |00>, |10> unchanged.
CNOT10_MATRIX = np.array(
    [[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]]
)


def cnot10():
    """CNOT with control qubit 1 and target qubit 0, cast to ``ctype``."""
    return K.cast(K.convert_to_tensor(CNOT10_MATRIX), ctype)


# Human-readable names for the operation pool; index k names the k-th gate.
ops_repr = ["rx0", "rx1", "ry0", "ry1", "rz0", "rz1", "cnot01", "cnot10"]
[4]:
# number of qubits, number of layers, size of the operation pool
n, p, ch = 2, 3, 8

# Target wavefunction: the two-qubit GHZ state (|00> + |11>) / sqrt(2).
target = tc.array_to_tensor(np.array([1, 0, 0, 1.0]) / np.sqrt(2.0))


def ansatz(params, structures):
    """Build the structure-weighted super-circuit and return its distance to ``target``.

    Each of the ``p`` layers applies, on qubits (0, 1), the sum over the 8 pool gates
    of ``structures[layer, k] * gate_k``; the first six gates take the angle
    ``params[layer, k]``.

    Args:
        params: rotation angles, indexed ``[layer, k]`` for the six parameterized gates.
        structures: per-layer gate weights, indexed ``[layer, k]`` for all 8 pool ops
            (one-hot rows when called with sampled layouts).

    Returns:
        The L1 distance ``sum_i |target_i - psi_i|`` between target and output state.
    """
    circuit = tc.Circuit(n)
    angles = K.cast(params, ctype)
    mixing = K.cast(structures, ctype)
    parameterized_gates = (rx0, rx1, ry0, ry1, rz0, rz1)
    fixed_gates = (cnot01, cnot10)
    for layer in range(p):
        # Accumulate terms left to right, in pool order.
        layer_unitary = mixing[layer, 0] * parameterized_gates[0](angles[layer, 0])
        for k in range(1, len(parameterized_gates)):
            layer_unitary = layer_unitary + mixing[layer, k] * parameterized_gates[k](
                angles[layer, k]
            )
        for k, gate in enumerate(fixed_gates, start=len(parameterized_gates)):
            layer_unitary = layer_unitary + mixing[layer, k] * gate()
        circuit.any(0, 1, unitary=layer_unitary)
    state = circuit.state()
    return K.sum(K.abs(target - state))


# Jitted value-and-grad of `ansatz` w.r.t. params (argnums=0), vectorized over a batch of
# structures (vectorized_argnums=1). NOTE(review): presumably returns per-sample losses and
# the params gradient summed over the batch — confirm against TensorCircuit's vvag docs.
vag1 = K.jit(K.vvag(ansatz, argnums=0, vectorized_argnums=1))

概率系综方法#

这种方法更贴近实际实验,与参考文献 [1] 中描述的算法相同;不过我们在这里使用了高级的 vmap 来加速对不同结构电路的模拟。

[5]:
def sampling_from_structure(structures, batch=1):
    """Sample one operation index per layer from the softmax of ``structures``.

    Args:
        structures: real logits with one row per layer and one column per pool op.
        batch: unused; kept for backward compatibility. Draw several samples by
            calling this function repeatedly.

    Returns:
        A numpy integer array holding one sampled op index per layer.
    """
    # Convert once instead of once per row.
    prob = K.numpy(K.softmax(K.real(structures), axis=-1))
    num_layers, num_ops = prob.shape
    return np.array(
        [np.random.choice(num_ops, p=prob[layer]) for layer in range(num_layers)]
    )


@K.jit
def best_from_structure(structures):
    """Return the highest-logit op index for every layer (row-wise argmax)."""
    return K.argmax(structures, axis=-1)


@K.jit
def nmf_gradient(structures, oh):
    """Monte Carlo score function for the naive mean-field (per-layer softmax) model.

    For the layout encoded by ``oh`` (op ``c_l`` chosen in layer ``l``), return
    d/d structures[l, j] of log p(layout) = delta(j, c_l) - softmax(structures[l])[j],
    i.e. ``onehot(choice) - softmax(structures)``, cast to ``ctype``.

    Args:
        structures: structure logits, one row per layer and one column per pool op.
        oh: one-hot layout with the same shape; the chosen op per layer is its argmax.
    """
    choice = K.argmax(oh, axis=-1)
    prob = K.softmax(K.real(structures), axis=-1)
    indices = K.transpose(K.stack([K.cast(tf.range(p), "int64"), choice]))
    # -softmax for every entry, then +1 at the chosen op of each layer. Every entry
    # must keep its own probability p_j; subtracting the chosen op's p_k from the
    # whole row is not the derivative of log-softmax.
    return tf.tensor_scatter_nd_add(
        tf.cast(-prob, dtype=ctype),
        indices,
        tf.ones([p], dtype=ctype),
    )


# Score functions for a whole batch of sampled layouts (vectorized over `oh`).
nmf_gradient_vmap = K.vmap(nmf_gradient, vectorized_argnums=1)
[6]:
verbose = False
epochs = 400
batch = 256
lr = tf.keras.optimizers.schedules.ExponentialDecay(0.06, 100, 0.5)
structure_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(0.12))
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(lr))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
stp = K.implicit_randn(stddev=0.02, shape=[p, ch], dtype=rtype)
avcost1 = 0
for epoch in range(epochs):  # each iteration updates structure and network parameters
    # Baseline for variance reduction: the previous epoch's batch-average loss.
    avcost2 = avcost1
    # Sample `batch` layouts from the current structure distribution, one-hot encoded.
    batched_structures = K.onehot(
        np.stack([sampling_from_structure(stp) for _ in range(batch)]), num=ch
    )
    infd, gnnp = vag1(nnp, batched_structures)
    gs = nmf_gradient_vmap(stp, batched_structures)  # per-sample \nabla ln p
    # Score-function estimate: batch mean of (loss - baseline) * \nabla ln p, done as a
    # single broadcast product instead of a Python list of per-sample tensors.
    centered_loss = K.reshape(K.cast(infd - avcost2, ctype), [-1, 1, 1])
    gstp = K.real(K.sum(centered_loss * gs, axis=0) / infd.shape[0])
    avcost1 = K.sum(infd) / infd.shape[0]
    nnp = network_opt.update(gnnp, nnp)
    stp = structure_opt.update(gstp, stp)

    if epoch % 40 == 0 or epoch == epochs - 1:
        print("----------epoch %s-----------" % epoch)
        print(
            "batched average loss: ",
            np.mean(avcost1),
        )

        if verbose:
            print(
                "strcuture parameter: \n",
                stp.numpy(),
                "\n network parameter: \n",
                nnp.numpy(),
            )

        cand_preset = best_from_structure(stp)
        print("best candidates so far:", [ops_repr[i] for i in cand_preset])
        # Only the first six pool ops carry an angle; CNOTs report 0.0.
        print(
            "corresponding weights for each gate:",
            [K.numpy(nnp[j, i]) if i < 6 else 0.0 for j, i in enumerate(cand_preset)],
        )
WARNING:tensorflow:Using a while_loop for converting GatherNd
WARNING:tensorflow:Using a while_loop for converting TensorScatterAdd
----------epoch 0-----------
batched average loss:  1.486862041224946
best candidates so far: ['rz1', 'cnot01', 'rx0']
corresponding weights for each gate: [0.04850114379068718, 0.0, 0.05130625869908137]
----------epoch 40-----------
batched average loss:  1.0129558433713033
best candidates so far: ['rx0', 'rx1', 'cnot01']
corresponding weights for each gate: [0.027262624897770482, 0.027810247234826772, 0.0]
----------epoch 80-----------
batched average loss:  0.05192747113248927
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7929784361217006, 0.028373579732963325, 0.0]
----------epoch 120-----------
batched average loss:  0.031656973667466226
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7917033798904415, 0.02709852348238091, 0.0]
----------epoch 160-----------
batched average loss:  0.028017594123095527
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7898273038100041, 0.02406457071696495, 0.0]
----------epoch 200-----------
batched average loss:  0.029086134952175734
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7878800008021832, 0.020169898669812416, 0.0]
----------epoch 240-----------
batched average loss:  0.02272153644755242
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7860795359107594, 0.016568960455492835, 0.0]
----------epoch 280-----------
batched average loss:  0.019205161854778285
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7854269350973763, 0.013422528595912621, 0.0]
----------epoch 320-----------
batched average loss:  0.015424930560900666
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7854255410759101, 0.010762703562019423, 0.0]
----------epoch 360-----------
batched average loss:  0.012287067999120332
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7854212965432693, 0.008565353750462018, 0.0]
----------epoch 399-----------
batched average loss:  0.009789006724779316
best candidates so far: ['ry0', 'rx1', 'cnot01']
corresponding weights for each gate: [-0.7922338874583758, -4.046275452326831e-05, 0.0]

直接优化结构参数#

不过,由于我们进行的是数值模拟,因此可以直接优化结构参数,而不必要求超级电路(各门的加权和)是幺正的——ansatz2 会对输出态重新归一化。这种方法在某些场景下可以更快、更可靠。

[7]:
def ansatz2(params, structures):
    """Loss for directly optimizing the (relaxed) structure parameters.

    The structure logits become per-layer softmax weights. Each of the ``p`` layers
    applies the softmax-weighted sum of the 8 pool gates to qubits (0, 1). The output
    state is then renormalized, because a weighted sum of unitaries is generally not
    unitary.

    Args:
        params: rotation angles, indexed ``[layer, k]`` for the six parameterized gates.
        structures: real structure logits, indexed ``[layer, k]`` for all 8 pool ops.

    Returns:
        The L1 distance ``sum_i |target_i - psi_i|`` for the normalized output state.
    """
    circuit = tc.Circuit(n)
    angles = K.cast(params, ctype)
    mixing = K.cast(K.softmax(structures, axis=-1), ctype)
    parameterized_gates = (rx0, rx1, ry0, ry1, rz0, rz1)
    fixed_gates = (cnot01, cnot10)
    for layer in range(p):
        # Accumulate terms left to right, in pool order.
        layer_unitary = mixing[layer, 0] * parameterized_gates[0](angles[layer, 0])
        for k in range(1, len(parameterized_gates)):
            layer_unitary = layer_unitary + mixing[layer, k] * parameterized_gates[k](
                angles[layer, k]
            )
        for k, gate in enumerate(fixed_gates, start=len(parameterized_gates)):
            layer_unitary = layer_unitary + mixing[layer, k] * gate()
        circuit.any(0, 1, unitary=layer_unitary)
    state = circuit.state()
    state = state / K.norm(state)
    return K.sum(K.abs(target - state))


# Jitted value and gradients w.r.t. both params and structure logits;
# called below as `infd, (gnnp, gstp) = vag2(nnp, stp)`.
vag2 = K.jit(K.value_and_grad(ansatz2, argnums=(0, 1)))
[8]:
verbose = True
epochs = 700
lr = tf.keras.optimizers.schedules.ExponentialDecay(0.05, 200, 0.5)
structure_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(0.04))
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(lr))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
stp = K.implicit_randn(stddev=0.02, shape=[p, ch], dtype=rtype)
for epoch in range(epochs):
    # One joint gradient step on network angles and structure logits.
    infd, (gnnp, gstp) = vag2(nnp, stp)
    nnp = network_opt.update(gnnp, nnp)
    stp = structure_opt.update(gstp, stp)

    # Report every 70 epochs and at the final epoch only.
    if epoch % 70 != 0 and epoch != epochs - 1:
        continue
    print("----------epoch %s-----------" % epoch)
    print("batched average loss: ", np.mean(infd))
    if verbose:
        print(
            "strcuture parameter: \n",
            stp.numpy(),
            "\n network parameter: \n",
            nnp.numpy(),
        )

    cand_preset = best_from_structure(stp)
    print("best candidates so far:", [ops_repr[i] for i in cand_preset])
    # Only the first six pool ops carry an angle; CNOTs report 0.0.
    gate_weights = [
        K.numpy(nnp[layer, op]) if op < 6 else 0.0
        for layer, op in enumerate(cand_preset)
    ]
    print("corresponding weights for each gate:", gate_weights)
----------epoch 0-----------
batched average loss:  1.3024341605187928
strcuture parameter:
 [[ 0.00265054  0.04495954  0.05265605  0.04751008  0.03309468  0.02743368
   0.03382795 -0.06647121]
 [ 0.03544281  0.03207712  0.03629811  0.0266235   0.03264895  0.03198189
   0.03505167 -0.03449981]
 [ 0.0304648   0.07042194  0.03075206  0.02515865  0.02984363  0.00955019
   0.07527341 -0.05831911]]
 network parameter:
 [[-0.0380125   0.0688923   0.04393423  0.04205065  0.06243917  0.03672062]
 [-0.05277717  0.04834309  0.05176114  0.07030034  0.02983666  0.04821408]
 [-0.04095011  0.0393773   0.03383929  0.06559557  0.03458135  0.02436751]]
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [0.043934227235354874, 0.05176113831452516, 0.0]
----------epoch 70-----------
batched average loss:  1.0078726220234666
strcuture parameter:
 [[ 0.34119556  0.37154559  0.2781548   0.37689646  1.79624262  1.78939153
   1.79791154 -0.3958881 ]
 [-1.04375011 -0.1280161  -0.98656309  0.35601588  1.79624062  1.7944131
   1.80069633 -0.36391841]
 [ 0.05278476  0.40486479  0.33074328  0.35455104  1.79289574  1.77260117
   1.83987217 -0.38773815]]
 network parameter:
 [[ 0.03213363 -0.00113532 -0.02629044  0.44615806 -0.00711866 -0.03271277]
 [ 0.01737624 -0.02169861 -0.01841347  0.47440825  0.01052075 -0.02137655]
 [ 0.02919352 -0.03064528 -0.03628025  0.46970421 -0.03434029 -0.04519493]]
best candidates so far: ['cnot01', 'cnot01', 'cnot01']
corresponding weights for each gate: [0.0, 0.0, 0.0]
----------epoch 140-----------
batched average loss:  0.8974790925982725
strcuture parameter:
 [[-0.60734424 -0.71178177  1.75016478  0.37059946  2.29304101  1.40087053
   2.75041722 -0.4812161 ]
 [-3.20848853 -2.62803529 -1.0799243   0.34964202  1.75222537  0.96340588
   4.59441388 -0.44922864]
 [-1.14052853 -0.10976557  0.9998582   0.34848638  3.81458301  3.38821681
   2.51193413 -0.38172792]]
 network parameter:
 [[-0.03052873 -0.00288839 -1.08840853  0.51536679 -0.00833568 -0.03375284]
 [ 0.0216315  -0.02917968  0.8216509   0.54353084  0.00946658 -0.02120069]
 [ 0.02581133 -0.03281417  0.94900358  0.46842014 -0.0355813  -0.04637678]]
best candidates so far: ['cnot01', 'cnot01', 'rz0']
corresponding weights for each gate: [0.0, 0.0, -0.03558129647764995]
----------epoch 210-----------
batched average loss:  0.06833171337421318
strcuture parameter:
 [[-1.46101866 -1.56562139  2.60460926  0.36697705  1.43861808  0.54649088
   1.8959831  -0.6664916 ]
 [-2.8599476  -3.71447893 -0.1444609  -0.31801788  0.70398108  0.12590137
   5.44896172 -0.64032593]
 [-0.29063004  0.46966802  1.85427644  0.44130728  4.56540152  4.13933904
   1.26524942 -0.47396024]]
 network parameter:
 [[-0.80254223 -0.07223631 -1.72113169  0.65719836 -0.00717114 -0.03248183]
 [ 0.86880067 -0.98161163  1.73279897  0.68543186  0.57173565  0.81901641]
 [ 0.41384514  0.03460214  1.4409846   0.26020133 -0.03450978 -0.04528429]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7211316854515077, 0.0, -0.03450978363305436]
----------epoch 280-----------
batched average loss:  0.07287093721912785
strcuture parameter:
 [[-1.46462962 -1.5694426   2.60832482  0.36325644  1.43489835  0.54276147
   1.89226573 -0.66280103]
 [-1.57089047 -3.66510778 -0.14837051 -0.32242832  0.70013882  0.1239224
   5.45278654 -0.63670874]
 [-0.28665073  0.50527464  1.85793113  0.43721505  4.56217177  4.13630294
   1.26107282 -0.47016776]]
 network parameter:
 [[-0.80172748 -0.0712512  -1.71930255  0.6591615  -0.00531003 -0.030584  ]
 [ 0.8833277  -1.26937984  1.73185803  0.68715766  0.57394473  0.8184381 ]
 [ 0.41194181  0.03270603  1.4387577   0.26174242 -0.03266689 -0.04343292]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7193025540909423, 0.0, -0.032666892917796356]
----------epoch 350-----------
batched average loss:  0.0763633455796077
strcuture parameter:
 [[-1.46759968 -1.57262214  2.61142979  0.36014718  1.43179046  0.53964649
   1.88915951 -0.65972411]
 [-1.34936013 -3.5935837  -0.15090151 -0.32639326  0.69677033  0.10483803
   5.45606073 -0.63438681]
 [-0.28282805  0.54032205  1.86109945  0.43217854  4.55947529  4.13345958
   1.25730198 -0.46703484]]
 network parameter:
 [[-0.80018367 -0.06928467 -1.71726239  0.66128651 -0.00323225 -0.0284642 ]
 [ 0.88958794 -1.36344637  1.73074888  0.68928184  0.57647947  0.81753049]
 [ 0.40992401  0.03064141  1.43664452  0.26328958 -0.03061012 -0.04136483]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7172623928298167, 0.0, -0.030610123400026196]
----------epoch 420-----------
batched average loss:  0.07933778661926846
strcuture parameter:
 [[-1.47009725 -1.57533451  2.61409128  0.35748195  1.42912701  0.53697792
   1.88649718 -0.65709335]
 [-1.19469005 -3.49291914 -0.15236124 -0.32995762  0.69380539  0.07513886
   5.45890732 -0.63296266]
 [-0.27952556  0.57900165  1.86385667  0.42698093  4.55720421  4.13093395
   1.25401354 -0.46439162]]
 network parameter:
 [[-7.98493487e-01 -6.71330183e-02 -1.71528444e+00  6.63308869e-01
  -1.20663132e-03 -2.63889444e-02]
 [ 8.92122476e-01 -1.41065660e+00  1.72949508e+00  6.91390756e-01
   5.79085213e-01  8.16347571e-01]
 [ 4.07963360e-01  2.86428770e-02  1.43467895e+00  2.64831277e-01
  -2.86101270e-02 -3.93481586e-02]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7152844386878365, 0.0, -0.028610127013962768]
----------epoch 490-----------
batched average loss:  0.08254545024351018
strcuture parameter:
 [[-1.47222071e+00 -1.57769163e+00  2.61641114e+00  3.55158679e-01
   1.42680584e+00  5.34652980e-01  1.88417675e+00 -6.54806876e-01]
 [-9.73250463e-01 -3.16903830e+00 -3.18559749e-03 -3.33158739e-01
   8.55776706e-01  5.24787309e-02  5.46137377e+00 -4.79944986e-01]
 [-4.31555592e-01  6.29503046e-01  1.86558679e+00  4.22659216e-01
   4.55614569e+00  4.13032405e+00  1.25003608e+00 -4.62133108e-01]]
 network parameter:
 [[-7.96868537e-01 -6.41205534e-02 -1.71348758e+00  6.65124861e-01
   7.01296452e-04 -2.44123401e-02]
 [ 8.95067083e-01 -1.46634629e+00  1.72823707e+00  6.93124762e-01
   6.19045003e-01  8.16911575e-01]
 [ 3.68577179e-01 -1.06859075e-02  1.43203849e+00  2.65575517e-01
  -2.67782067e-02 -3.74965029e-02]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7134875797200289, 0.0, -0.026778206681234634]
----------epoch 560-----------
batched average loss:  0.08507962845391319
strcuture parameter:
 [[-1.47404605e+00 -1.57976616e+00  2.61845726e+00  3.53109374e-01
   1.42475889e+00  5.32603322e-01  1.88213024e+00 -6.52796788e-01]
 [-9.74391719e-01 -2.60249357e+00 -5.45311032e-03 -3.36090881e-01
   8.87457853e-01  5.09795396e-02  5.46352386e+00 -4.77818480e-01]
 [-4.31981604e-01  7.01277774e-01  1.86614168e+00  4.17460986e-01
   4.55685554e+00  4.13219105e+00  1.24506389e+00 -4.60189080e-01]]
 network parameter:
 [[-0.7953961  -0.05746862 -1.7119146   0.66669705  0.00291374 -0.0224405 ]
 [ 0.90083188 -1.50990732  1.72710663  0.69459133  0.61788545  0.82003669]
 [ 0.37017675 -0.00904591  1.42870116  0.26514657 -0.02513608 -0.03583994]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.711914597657612, 0.0, -0.025136080292754256]
----------epoch 630-----------
batched average loss:  0.0874301658835076
strcuture parameter:
 [[-1.47562848 -1.58159874  2.62027778  0.35128611  1.42293816  0.53078116
   1.88030961 -0.65101441]
 [-0.97538969 -2.02211517 -0.00750621 -0.33876198  0.90002101  0.04975063
   5.46545544 -0.47588949]
 [-0.43177273  0.78495553  1.86618092  0.41182654  4.55815983  4.13468705
   1.24131582 -0.45851101]]
 network parameter:
 [[-0.79411737 -0.03262083 -1.71056973  0.66801607  0.00931137 -0.02061058]
 [ 0.90559687 -1.52299082  1.72605402  0.69584073  0.61673658  0.82270295]
 [ 0.37157164 -0.00763778  1.42576784  0.26501218 -0.02367246 -0.03437195]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.7105697331825656, 0.0, -0.02367245788118412]
----------epoch 699-----------
batched average loss:  0.07954282768319362
strcuture parameter:
 [[-1.38724436 -1.4934299   2.53210121  0.43945868  1.51111454  0.61895662
   1.96848611 -0.7392194 ]
 [-0.88692788 -1.19022046  0.08057521 -0.25156948  0.81347111 -0.04509805
   5.3775582  -0.56406862]
 [-0.34354263  0.94817808  1.77696399  0.48982178  4.64871759  4.22638916
   1.32761078 -0.54681732]]
 network parameter:
 [[-0.78310707  0.01573153 -1.69951606  0.67901289  0.0447718  -0.00935101]
 [ 0.91920164 -1.51646333  1.7152693   0.70682227  0.60494688  0.8352184 ]
 [ 0.38342071  0.00423554  1.41389299  0.27478806 -0.01255172 -0.02324825]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6995160647885235, 0.0, -0.01255171733214709]

最后的微调#

对于获得的电路布局,我们可以进一步调整电路权重,使目标函数更接近于零。

[9]:
# Fixed layout for fine-tuning: layer 0 -> ry0 (index 2), layer 1 -> rz0 (index 4),
# layer 2 -> cnot01 (index 6), one-hot encoded and given a leading batch axis of 1.
# NOTE(review): the direct search above reports ['ry0', 'cnot01', 'rz0'] (indices
# [2, 6, 4]); this layout places rz0 before cnot01 instead. Both reduce to ry0 then
# cnot01 when the rz0 angle is ~0 — confirm this is the intended layout.
chosen_structure = K.reshape(K.onehot(np.array([2, 4, 6]), num=8), [1, p, ch])
chosen_structure
[9]:
<tf.Tensor: shape=(1, 3, 8), dtype=float32, numpy=
array([[[0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.]]], dtype=float32)>
[10]:
# Fine-tune only the rotation angles for the fixed layout, from a fresh random start.
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(1e-3))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
verbose = True
epochs = 600
for epoch in range(epochs):
    infd, gnnp = vag1(nnp, chosen_structure)
    nnp = network_opt.update(gnnp, nnp)
    # Log every 60 epochs and at the final epoch only.
    if epoch % 60 != 0 and epoch != epochs - 1:
        continue
    print(epoch, "loss: ", K.numpy(infd[0]))
0 loss:  1.004872758871745
60 loss:  0.9679200431227549
120 loss:  0.9091748060127385
180 loss:  0.8302632245154631
240 loss:  0.73297561645977
300 loss:  0.6183436622858383
360 loss:  0.4874390992051161
420 loss:  0.34168453047914643
480 loss:  0.18299822849737177
540 loss:  0.013923344772669728
599 loss:  0.0016133833518836463

参考资料#

  1. https://arxiv.org/pdf/2010.08561.pdf