Utilizing vmap in Quantum Circuit Simulations#

Overview#

We apply vmap, an advanced feature of modern machine learning libraries, to quantum circuit simulation. By vmapping different ingredients of the simulation, we can implement variational quantum algorithms with high efficiency.

It is worth noting that in all of the following use cases, vmap composes with jit and automatic differentiation (AD), which yields highly efficient differentiable simulation.

The ingredients that support the vmap paradigm are shown in the following figure.

(Figure: vmap ingredients)

We have two APIs for vmap: the first is vmap itself, and the second is vectorized_value_and_grad, a.k.a. vvag. The latter additionally returns the gradient information over a batch of different circuits.

If batch evaluation of gradients as well as function values is required, this can be done via vectorized_value_and_grad. In the simplest case, consider a function \(f(x,y)\) where \(x\in \mathbb{R}^p, y\in \mathbb{R}^q\) are both vectors, and suppose one wishes to evaluate both \(f(x,y)\) and \(\sum_x\nabla_y f(x,y) = \sum_x\left( \frac{\partial f(x,y)}{\partial y_1},\ldots, \frac{\partial f(x,y)}{\partial y_q}\right)^\top\) over a batch \(x_1, x_2,\ldots, x_k\) of inputs \(x\). This is achieved by creating a new, vectorized value-and-gradient function:

[1]:
%%latex
\begin{equation}
f_{vvg}\left( \begin{pmatrix} \leftarrow x_1 \rightarrow\\ \vdots \\ \leftarrow x_k \rightarrow\end{pmatrix}, y \right) =
\begin{pmatrix} \begin{pmatrix}f(x_1, y) \\ \vdots \\
f(x_k,y)\end{pmatrix},\sum_{i=1}^k \nabla_y f(x_i,y) \end{pmatrix}
\end{equation}

which takes as zeroth argument the batched inputs expressed as a \(k\times p\) tensor, and as first argument the variables we wish to differentiate with respect to. The outputs are a vector of function values evaluated at all points \((x_i,y)\), and the gradient summed over all those points.
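
To make the shapes concrete, here is a toy sketch of the vvag signature on a non-circuit function. This is our illustration rather than one of the original notebook cells; it assumes the TensorFlow backend configured in the Setup below.

import tensorcircuit as tc

tc.set_backend("tensorflow")


def f(x, y):
    # f(x, y) = Re<x, y>, a scalar function of two vectors
    return tc.backend.real(tc.backend.sum(x * y))


f_vvg = tc.backend.vvag(f, argnums=1, vectorized_argnums=0)

xs = tc.backend.ones([3, 4])  # batch of k=3 inputs, each of dimension p=4
y = tc.backend.ones([4])  # shared variables of dimension q=4
values, grad = f_vvg(xs, y)
# values has shape (3,): f(x_i, y) for each batch element
# grad has shape (4,): sum_i grad_y f(x_i, y)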

Setup#

[2]:
import numpy as np
import tensorcircuit as tc

tc.set_backend("tensorflow")
print(tc.__version__)

nwires = 5
nlayers = 2
batch = 6
0.0.220509

vmap the Input States#

Use case: batch processing of input states in quantum machine learning tasks.

For applications of batched input state processing, please see the MNIST QML tutorial.

Minimal Example#

[3]:
def f(inputs, weights):
    c = tc.Circuit(nwires, inputs=inputs)  # use the given wavefunction as input state
    c = tc.templates.blocks.example_block(c, weights, nlayers=nlayers)
    loss = c.expectation([tc.gates.z(), [2]])  # <Z> on qubit 2
    loss = tc.backend.real(loss)
    return loss


f_vg = tc.backend.jit(tc.backend.vvag(f, argnums=1, vectorized_argnums=0))
f_vg(tc.backend.ones([batch, 2**nwires]), tc.backend.ones([2 * nlayers, nwires]))
[3]:
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([10.88678, 10.88678, 10.88678, 10.88678, 10.88678, 10.88678],
       dtype=float32)>,
 <tf.Tensor: shape=(4, 5), dtype=complex64, numpy=
 array([[ 0.0000000e+00+1.3064140e+02j, -1.1444092e-05+1.3064142e+02j,
          0.0000000e+00+1.3064140e+02j,  0.0000000e+00+1.3064139e+02j,
          0.0000000e+00+0.0000000e+00j],
        [-1.9073486e-06-5.1765751e-06j, -5.1105431e+01-5.7347143e-07j,
         -8.1339760e+01-6.6063179e+01j, -5.1105446e+01+3.3477118e-06j,
         -7.6293945e-06+1.5500746e-07j],
        [ 0.0000000e+00+8.4607742e+01j, -1.3292285e+02+1.1209973e+02j,
         -1.3292284e+02+1.1209971e+02j,  1.5258789e-05+8.4607750e+01j,
          0.0000000e+00+0.0000000e+00j],
        [ 1.9073486e-06+5.9908474e+01j, -1.5258789e-05-1.9285599e+01j,
         -8.1339752e+01+3.8049275e-06j,  3.8146973e-06-1.9285591e+01j,
         -9.5367432e-06+5.9908482e+01j]], dtype=complex64)>)
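
If only the batched forward values are needed (no gradients), the plain vmap API mentioned above suffices. A minimal sketch reusing the same f:

# batched forward evaluation only, no gradients
f_batched = tc.backend.jit(tc.backend.vmap(f, vectorized_argnums=0))
f_batched(tc.backend.ones([batch, 2**nwires]), tc.backend.ones([2 * nlayers, nwires]))
# returns a shape (6,) tensor of expectation values, one per input state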

vmap the Circuit Weights#

Use case: batched VQE, where multiple random parameter initializations are optimized simultaneously.

For an application of batched VQE, please refer to the TFIM VQE tutorial.

Minimal Example#

[4]:
def f(weights):
    c = tc.Circuit(nwires)
    c = tc.templates.blocks.example_block(c, weights, nlayers=nlayers)
    loss = c.expectation([tc.gates.z(), [2]])
    loss = tc.backend.real(loss)
    return loss


f_vg = tc.backend.jit(tc.backend.vvag(f, argnums=0, vectorized_argnums=0))
f_vg(tc.backend.ones([batch, 2 * nlayers, nwires]))
[4]:
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([-2.9802322e-08, -2.9802322e-08, -2.9802322e-08, -2.9802322e-08,
        -2.9802322e-08, -2.9802322e-08], dtype=float32)>,
 <tf.Tensor: shape=(6, 4, 5), dtype=complex64, numpy=
 array([[[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]],

        [[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]],

        [[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]],

        [[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]],

        [[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]],

        [[ 1.1614500e-08+2.1480869e-08j, -9.2439478e-10-1.8808342e-08j,
           2.6397275e-08-8.0511313e-09j,  2.7981415e-08-1.6564460e-08j,
           0.0000000e+00+0.0000000e+00j],
         [ 4.1470027e-09-1.9918247e-08j, -7.7494953e-09+9.5806874e-09j,
           0.0000000e+00-1.3076999e-08j,  1.2109957e-09+3.2571617e-08j,
          -1.0110498e-08+1.6951747e-08j],
         [ 1.1614500e-08-1.0295013e-08j, -1.6102263e-08+2.5077789e-08j,
          -3.2204525e-08+5.0155577e-08j,  1.8346144e-08-3.5633683e-09j,
           0.0000000e+00+0.0000000e+00j],
         [-7.1439974e-09-3.3070933e-09j,  0.0000000e+00+6.8412485e-09j,
           1.4287995e-08-9.5050003e-09j, -7.1439974e-09-6.5384995e-09j,
          -2.3792996e-08-7.8779987e-09j]]], dtype=complex64)>)
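
Since the returned gradient is itself batched (shape (6, 4, 5)), each random initialization can be updated independently. A minimal sketch of one plain gradient-descent step, a simplification of a real optimizer loop:

weights = tc.backend.ones([batch, 2 * nlayers, nwires])
values, grad = f_vg(weights)
# one gradient-descent step applied to all batch members at once;
# cast back to complex64 to match the backend's default parameter dtype
weights = weights - tc.backend.cast(0.01 * tc.backend.real(grad), "complex64")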

vmap the Quantum Noise#

Use case: parallel Monte Carlo noise simulation.

For applications that combine vmapped Monte Carlo noise simulation with quantum machine learning tasks, please see the noisy QML script.

Minimal Example#

[5]:
def f(weights, status):
    c = tc.Circuit(nwires)
    c = tc.templates.blocks.example_block(c, weights, nlayers=nlayers)
    for i in range(nwires):
        # status[i] in [0, 1) selects which Kraus branch fires for qubit i
        # in this Monte Carlo trajectory; vmapping over status runs
        # many trajectories in parallel
        c.depolarizing(i, px=0.2, py=0.2, pz=0.2, status=status[i])
    loss = c.expectation([tc.gates.x(), [2]])
    loss = tc.backend.real(loss)
    return loss


f_vg = tc.backend.jit(tc.backend.vvag(f, argnums=0, vectorized_argnums=1))


def g(weights):
    status = tc.backend.implicit_randu(shape=[batch, nwires])
    return f_vg(weights, status)


g(tc.backend.ones([2 * nlayers, nwires]))
[5]:
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([ 0.34873545, -0.34873545, -0.34873545, -0.34873545, -0.34873545,
         0.34873545], dtype=float32)>,
 <tf.Tensor: shape=(4, 5), dtype=complex64, numpy=
 array([[-8.8614023e-01-1.7026657e-08j,  5.7763958e-01+2.3834804e-01j,
          5.7763910e-01+2.3834780e-01j, -8.8614047e-01+1.2538894e-07j,
          0.0000000e+00+0.0000000e+00j],
        [ 0.0000000e+00+6.9313288e-01j,  3.6122650e-01-1.1496974e-02j,
         -5.2079970e-01-1.7869800e-01j,  3.6122644e-01-1.1496985e-02j,
         -2.9802322e-08+6.9313288e-01j],
        [-5.9604645e-08-1.0189922e+00j, -3.0850098e-01-1.4861794e-07j,
         -3.0850050e-01-2.2304604e-08j,  5.9604645e-08-1.0189921e+00j,
          0.0000000e+00+0.0000000e+00j],
        [ 0.0000000e+00+3.1868588e-02j, -8.9406967e-08+2.5656950e-01j,
         -2.9802322e-08-1.9999983e+00j, -2.9802322e-08+2.5656945e-01j,
         -1.1920929e-07+3.1868652e-02j]], dtype=complex64)>)
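
Each batch entry above corresponds to one Monte Carlo trajectory, so a noisy expectation estimate follows from averaging the batched values. A short sketch using the objects defined above:

values, grad = g(tc.backend.ones([2 * nlayers, nwires]))
# Monte Carlo estimate of the noisy expectation: average over trajectories
noisy_loss = tc.backend.sum(values) / batch
# vvag sums the gradient over the batch, so dividing by the batch size
# gives the gradient of the averaged loss
noisy_grad = grad / batch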

vmap the Circuit Structure#

Use case: differentiable quantum architecture search (DQAS).

For more details on the DQAS application, see the DQAS tutorial.

Minimal Example#

[6]:
eye = tc.gates.i().tensor
x = tc.gates.x().tensor
y = tc.gates.y().tensor
z = tc.gates.z().tensor


def f(params, structures):
    c = tc.Circuit(nwires)
    for i in range(nwires):
        c.H(i)
    for j in range(nlayers):
        for i in range(nwires - 1):
            c.cz(i, i + 1)
        for i in range(nwires):
            c.unitary(
                i,
                unitary=structures[i, j, 0]
                * (
                    tc.backend.cos(params[i, j, 0]) * eye
                    + tc.backend.sin(params[i, j, 0]) * x
                )
                + structures[i, j, 1]
                * (
                    tc.backend.cos(params[i, j, 1]) * eye
                    + tc.backend.sin(params[i, j, 1]) * y
                )
                + structures[i, j, 2]
                * (
                    tc.backend.cos(params[i, j, 2]) * eye
                    + tc.backend.sin(params[i, j, 2]) * z
                ),
            )
    loss = c.expectation([tc.gates.z(), (2,)])
    return tc.backend.real(loss)


structures = tc.backend.ones([batch, nwires, nlayers, 3])
params = tc.backend.ones([nwires, nlayers, 3])
f_vg = tc.backend.jit(tc.backend.vvag(f, argnums=0, vectorized_argnums=1))
f_vg(params, structures)
[6]:
(<tf.Tensor: shape=(6,), dtype=float32, numpy=
 array([2.4917054e+08, 2.4917054e+08, 2.4917054e+08, 2.4917054e+08,
        2.4917054e+08, 2.4917054e+08], dtype=float32)>,
 <tf.Tensor: shape=(5, 2, 3), dtype=complex64, numpy=
 array([[[-4.8252989e+08+2.3603376e+07j, -6.4132224e+08+1.1064736e+08j,
          -4.5701562e+08-7.4987272e+07j],
         [-5.4175347e+08+5.2096408e+07j, -5.5254317e+08-4.6495180e+07j,
          -4.5219101e+08-5.6013205e+06j]],

        [[-7.1430163e+08-1.2090212e+08j, -6.2410163e+08-4.1363908e+07j,
          -3.9189485e+08+4.0016840e+06j],
         [-5.8365677e+08+9.4236816e+07j, -5.7693280e+08-9.7727496e+07j,
          -3.9540646e+08+3.4906362e+06j]],

        [[-5.9637555e+08+8.9477632e+07j, -7.6615610e+08+1.1949610e+08j,
          -3.8039136e+08-4.7556400e+07j],
         [-1.1637092e+09-4.0144461e+08j, -1.1735478e+09+4.5104198e+08j,
          -1.5947418e+08+1.5322706e+07j]],

        [[-7.1430170e+08-1.2090210e+08j, -6.2410170e+08-4.1363864e+07j,
          -3.9189485e+08+4.0016840e+06j],
         [-5.8365658e+08+9.4236840e+07j, -5.7693261e+08-9.7727496e+07j,
          -3.9540637e+08+3.4906552e+06j]],

        [[-4.8253002e+08+2.3603400e+07j, -6.4132237e+08+1.1064734e+08j,
          -4.5701565e+08-7.4987248e+07j],
         [-5.4175334e+08+5.2096460e+07j, -5.5254304e+08-4.6495116e+07j,
          -4.5219091e+08-5.6013120e+06j]]], dtype=complex64)>)
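
In a DQAS workflow, structures is typically a (relaxed) one-hot encoding over the candidate gates rather than all ones as in this minimal example. A sketch of building a batch of hard one-hot structures; the random choices here are illustrative only:

# each batch element picks one of the three generators (X, Y, Z) per placement
choices = np.random.randint(0, 3, size=[batch, nwires, nlayers])
structures_onehot = tc.backend.cast(
    tc.backend.convert_to_tensor(np.eye(3)[choices]), "complex64"
)
f_vg(params, structures_onehot)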

vmap the Circuit Measurements#

Use case: accelerating the evaluation of a Pauli string sum by parallelizing the parameterized measurements.

For applications that evaluate parameterized measurements via vmap on large-scale systems, see the large-scale VQE example script.

Minimal Example#

[7]:
def f(params, structures):
    c = tc.Circuit(nwires)
    c = tc.templates.blocks.example_block(c, params, nlayers=nlayers)
    loss = tc.templates.measurements.parameterized_measurements(
        c, structures, onehot=True
    )
    return loss


# measure X0 to X4: one Pauli-X string per qubit
structures = tc.backend.eye(nwires)
f_vvag = tc.backend.jit(tc.backend.vvag(f, vectorized_argnums=1, argnums=0))
f_vvag(tc.backend.ones([2 * nlayers, nwires]), structures)
WARNING:tensorflow:5 out of the last 5 calls to <function TensorFlowBackend.vectorized_value_and_grad.<locals>.wrapper at 0x7fe6cbed1af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
[7]:
(<tf.Tensor: shape=(5,), dtype=float32, numpy=
 array([-0.3118263 ,  0.00371493,  0.3487355 ,  0.00371514, -0.31182614],
       dtype=float32)>,
 <tf.Tensor: shape=(4, 5), dtype=complex64, numpy=
 array([[ 1.6707865e+00-0.40178323j, -1.1992662e+00-0.23834792j,
         -1.1992660e+00-0.2383478j ,  1.6707866e+00-0.40178335j,
          0.0000000e+00+0.j        ],
        [-1.8267021e-01-0.6483071j ,  7.7729575e-02+0.58401704j,
         -1.0082662e-01-0.52953976j,  7.7729806e-02+0.58401704j,
         -1.8267024e-01-0.6483072j ],
        [ 1.6707866e+00+0.19420199j, -1.1992658e+00+0.50487465j,
         -1.1992657e+00+0.504875j  ,  1.6707867e+00+0.19420168j,
          0.0000000e+00+0.j        ],
        [ 7.4505806e-09+0.99540246j,  1.4901161e-08+0.7925009j ,
         -7.4505806e-09+0.71156096j, -7.4505806e-09+0.7925008j ,
          2.2351742e-08+0.9954027j ]], dtype=complex64)>)
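
Because vvag sums the gradient over the batch, the returned gradient is already the gradient of the Pauli-string sum \(H = X_0 + X_1 + X_2 + X_3 + X_4\); the matching energy is simply the sum of the batched values. A short sketch:

values, grad = f_vvag(tc.backend.ones([2 * nlayers, nwires]), structures)
# energy of H = X0 + X1 + X2 + X3 + X4 and its gradient w.r.t. the weights
energy = tc.backend.sum(values)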