Differentiable Quantum Architecture Search#
Overview#
This tutorial demonstrates how to use the advanced computational features of TensorCircuit, such as jit and vmap, to efficiently simulate the differentiable quantum architecture search (DQAS) algorithm, where an ensemble of quantum circuits with different structures is compiled and simulated at the same time.
[WIP note]
Setup#
[1]:
import numpy as np
import tensorcircuit as tc
import tensorflow as tf
[2]:
K = tc.set_backend("tensorflow")
ctype, rtype = tc.set_dtype("complex128")
Problem Description#
The task is to find a state-preparation circuit for the GHZ state \(\vert \text{GHZ}_N\rangle = \frac{1}{\sqrt{2}}\left(\vert 0^N\rangle +\vert 1^N\rangle \right)\). For the \(N=2\) demo, we prepare a gate pool containing rx0, rx1, ry0, ry1, rz0, rz1, cnot01, and cnot10. Of these eight gates, six are parameterized.
[3]:
# Each rotation is cos(theta) * I + 1j * sin(theta) * P = exp(1j * theta * P)
# for a Pauli matrix P, embedded on qubit 0 (left kron factor) or qubit 1 (right).
def rx0(theta):
    return K.kron(
        K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._x_matrix, K.eye(2)
    )


def rx1(theta):
    return K.kron(
        K.eye(2), K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._x_matrix
    )


def ry0(theta):
    return K.kron(
        K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._y_matrix, K.eye(2)
    )


def ry1(theta):
    return K.kron(
        K.eye(2), K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._y_matrix
    )


def rz0(theta):
    return K.kron(
        K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._z_matrix, K.eye(2)
    )


def rz1(theta):
    return K.kron(
        K.eye(2), K.cos(theta) * K.eye(2) + 1.0j * K.sin(theta) * tc.gates._z_matrix
    )


def cnot01():
    return K.cast(K.convert_to_tensor(tc.gates._cnot_matrix), ctype)


def cnot10():
    # as written, this matrix swaps |00> and |01>,
    # i.e. it flips qubit 1 when qubit 0 is |0>
    return K.cast(
        K.convert_to_tensor(
            np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
        ),
        ctype,
    )


# human-readable names, in the same order as the structure-parameter columns
ops_repr = ["rx0", "rx1", "ry0", "ry1", "rz0", "rz1", "cnot01", "cnot10"]
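As a sanity check on the gate convention used in this pool, the following NumPy-only sketch confirms that one `ry0` rotation followed by `cnot01` already prepares the target state. The helper `ry0_np` and the explicit CNOT matrix are illustrative stand-ins (the CNOT is assumed to match `tc.gates._cnot_matrix`); they are not part of the tutorial code.

```python
import numpy as np

# Pauli-Y and the two-qubit CNOT (control 0, target 1), written out explicitly
Y = np.array([[0, -1j], [1j, 0]])
CNOT01 = np.array(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex
)


def ry0_np(theta):
    # same convention as ry0: (cos(theta) I + i sin(theta) Y) on qubit 0
    single = np.cos(theta) * np.eye(2) + 1j * np.sin(theta) * Y
    return np.kron(single, np.eye(2))


zero_state = np.array([1, 0, 0, 0], dtype=complex)  # |00>
ghz = np.array([1, 0, 0, 1], dtype=complex) / np.sqrt(2)

# ry0(-pi/4) maps |0> to (|0> + |1>)/sqrt(2); cnot01 then entangles the qubits
prepared = CNOT01 @ ry0_np(-np.pi / 4) @ zero_state
print(np.allclose(prepared, ghz))
```

In this convention the required rotation angle is \(-\pi/4\), so a successful search should find rotation weights on qubit 0 that sum to roughly that value.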
[4]:
n, p, ch = 2, 3, 8
# number of qubits, number of layers, size of operation pool

target = tc.array_to_tensor(np.array([1, 0, 0, 1.0]) / np.sqrt(2.0))
# target wavefunction; here the GHZ2 state is the objective


def ansatz(params, structures):
    c = tc.Circuit(n)
    params = K.cast(params, ctype)
    structures = K.cast(structures, ctype)
    for i in range(p):
        # layer i: structure-weighted sum over the whole gate pool
        c.any(
            0,
            1,
            unitary=structures[i, 0] * rx0(params[i, 0])
            + structures[i, 1] * rx1(params[i, 1])
            + structures[i, 2] * ry0(params[i, 2])
            + structures[i, 3] * ry1(params[i, 3])
            + structures[i, 4] * rz0(params[i, 4])
            + structures[i, 5] * rz1(params[i, 5])
            + structures[i, 6] * cnot01()
            + structures[i, 7] * cnot10(),
        )
    s = c.state()
    loss = K.sum(K.abs(target - s))
    return loss


# gradient with respect to params (argnums=0), vectorized over structures
vag1 = K.jit(K.vvag(ansatz, argnums=0, vectorized_argnums=1))
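When a row of `structures` is one-hot, the weighted sum inside `ansatz` collapses to a single gate from the pool, and vectorizing over the structure argument evaluates many such choices at once. The NumPy sketch below is only an analogy for that batching idea, not TensorCircuit's implementation; `pool` holds arbitrary placeholder matrices rather than real gates.

```python
import numpy as np

rng = np.random.default_rng(0)
pool_size, dim, batch = 8, 4, 5

# stand-in "gate pool": arbitrary 4x4 matrices, one per pool entry (placeholder data)
pool = rng.normal(size=(pool_size, dim, dim))

# a batch of sampled gate indices and their one-hot encodings
choices = rng.integers(0, pool_size, size=batch)
one_hot = np.eye(pool_size)[choices]  # shape (batch, pool_size)

# weighted sum over the pool, evaluated for the whole batch in one contraction
layer = np.einsum("bk,kij->bij", one_hot, pool)  # shape (batch, dim, dim)

# each batch element is exactly the pool matrix its one-hot row points to
print(all(np.allclose(layer[b], pool[choices[b]]) for b in range(batch)))
```

The same contraction with non-one-hot weights gives a genuine mixture of pool entries, which is the case used later when the structure parameters are optimized directly.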
Probability Ensemble Approach#
This approach is more practical and experimentally relevant. It is the same algorithm described in Ref.1, though here we use vmap to accelerate the simulation of circuits with different structures.
[5]:
def sampling_from_structure(structures, batch=1):
    # sample one gate index per layer from the softmax of the structure parameters
    prob = K.softmax(K.real(structures), axis=-1)
    return np.array([np.random.choice(ch, p=K.numpy(prob[i])) for i in range(p)])


@K.jit
def best_from_structure(structures):
    # most probable gate index for each layer
    return K.argmax(structures, axis=-1)


@K.jit
def nmf_gradient(structures, oh):
    """
    compute the Monte Carlo gradient with respect to the naive mean-field probabilistic model
    """
    choice = K.argmax(oh, axis=-1)
    prob = K.softmax(K.real(structures), axis=-1)
    # (layer, chosen gate) index pairs
    indices = K.transpose(K.stack([K.cast(tf.range(p), "int64"), choice]))
    # probability of the chosen gate in each layer, repeated across the pool axis
    prob = tf.gather_nd(prob, indices)
    prob = K.reshape(prob, [-1, 1])
    prob = K.tile(prob, [1, ch])
    # start from -prob and add 1 at the chosen gate of each layer
    return tf.tensor_scatter_nd_add(
        tf.cast(-prob, dtype=ctype),
        indices,
        tf.ones([p], dtype=ctype),
    )


nmf_gradient_vmap = K.vmap(nmf_gradient, vectorized_argnums=1)
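The Monte Carlo update needs the gradient of the log-probability of each sampled gate with respect to the structure parameters. For a softmax distribution with sampled index \(k\), the exact form is \(\partial \ln p_k / \partial \alpha_j = \delta_{jk} - p_j\). The pure-Python sketch below checks that identity against finite differences; the names `score` and `score_fd` and the example logits are illustrative and not part of the tutorial. Note that `nmf_gradient` fills each row with minus the probability of the sampled gate before adding one at the sampled index, so its off-choice entries are \(-p_k\) rather than \(-p_j\); the sketch is a reference for the textbook expression, not a replacement for that cell.

```python
import math


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score(logits, k):
    # analytic gradient of ln softmax(logits)[k] with respect to each logit
    probs = softmax(logits)
    return [(1.0 if j == k else 0.0) - probs[j] for j in range(len(logits))]


def score_fd(logits, k, eps=1e-6):
    # central finite differences of ln p_k, used only as a check
    grads = []
    for j in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[j] += eps
        down[j] -= eps
        grads.append(
            (math.log(softmax(up)[k]) - math.log(softmax(down)[k])) / (2 * eps)
        )
    return grads


alpha = [0.3, -1.2, 0.8, 0.1]  # example logits for one layer (made-up values)
analytic = score(alpha, k=2)
numeric = score_fd(alpha, k=2)
print(max(abs(a - b) for a, b in zip(analytic, numeric)) < 1e-6)
```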
[6]:
verbose = False
epochs = 400
batch = 256
lr = tf.keras.optimizers.schedules.ExponentialDecay(0.06, 100, 0.5)
structure_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(0.12))
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(lr))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
stp = K.implicit_randn(stddev=0.02, shape=[p, 8], dtype=rtype)
avcost1 = 0
for epoch in range(epochs):  # iteration to update structure parameters
    avcost2 = avcost1  # previous batch-average loss, used as the baseline
    costl = []
    batched_structures = K.onehot(
        np.stack([sampling_from_structure(stp) for _ in range(batch)]), num=8
    )
    infd, gnnp = vag1(nnp, batched_structures)
    gs = nmf_gradient_vmap(stp, batched_structures)  # \nabla lnp
    gstp = [K.cast((infd[i] - avcost2), ctype) * gs[i] for i in range(infd.shape[0])]
    gstp = K.real(K.sum(gstp, axis=0) / infd.shape[0])
    avcost1 = K.sum(infd) / infd.shape[0]
    nnp = network_opt.update(gnnp, nnp)
    stp = structure_opt.update(gstp, stp)
    if epoch % 40 == 0 or epoch == epochs - 1:
        print("----------epoch %s-----------" % epoch)
        print(
            "batched average loss: ",
            np.mean(avcost1),
        )
        if verbose:
            print(
                "strcuture parameter: \n",
                stp.numpy(),
                "\n network parameter: \n",
                nnp.numpy(),
            )
        cand_preset = best_from_structure(stp)
        print("best candidates so far:", [ops_repr[i] for i in cand_preset])
        print(
            "corresponding weights for each gate:",
            [K.numpy(nnp[j, i]) if i < 6 else 0.0 for j, i in enumerate(cand_preset)],
        )
WARNING:tensorflow:Using a while_loop for converting GatherNd
WARNING:tensorflow:Using a while_loop for converting TensorScatterAdd
----------epoch 0-----------
batched average loss: 1.438692604002888
best candidates so far: ['cnot01', 'rx0', 'rx1']
corresponding weights for each gate: [0.0, -0.049711242696246834, 0.0456804722145847]
----------epoch 40-----------
batched average loss: 1.0024311791127296
best candidates so far: ['cnot01', 'ry0', 'cnot01']
corresponding weights for each gate: [0.0, -0.1351106165832465, 0.0]
----------epoch 80-----------
batched average loss: 0.09550699673720528
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.06370593607560585, -0.7355997299177472, 0.0]
----------epoch 120-----------
batched average loss: 0.0672150785213724
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.062430880135008346, -0.7343246757666638, 0.0]
----------epoch 160-----------
batched average loss: 0.07052086338808516
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.060554804305087445, -0.7324486014485383, 0.0]
----------epoch 200-----------
batched average loss: 0.06819711768556835
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.05860750144346523, -0.7305012995010937, 0.0]
----------epoch 240-----------
batched average loss: 0.05454652406620351
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.05680703664615186, -0.728700835323507, 0.0]
----------epoch 280-----------
batched average loss: 0.047745385543626825
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.05680097807715014, -0.7286947772784904, 0.0]
----------epoch 320-----------
batched average loss: 0.039626618064439574
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.05679499116702013, -0.7286887907723886, 0.0]
----------epoch 360-----------
batched average loss: 0.036450806118657045
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.056789547021157315, -0.7286833469676062, 0.0]
----------epoch 399-----------
batched average loss: 0.012538933640035648
best candidates so far: ['ry0', 'ry0', 'cnot01']
corresponding weights for each gate: [-0.06360204206353537, -0.7354958422632526, 0.0]
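The structure update in the loop above is a score-function (REINFORCE-style) estimator: each sampled structure contributes its loss, minus a baseline taken from the previous batch average, times the gradient of its log-probability. The toy sketch below isolates that mechanism for a single layer with three candidates. It is a simplification under stated assumptions: the per-candidate losses are made up, plain gradient descent replaces Adam, and the exact softmax score \(\delta_{jk} - p_j\) is used.

```python
import math
import random

random.seed(0)

losses = [1.0, 0.2, 0.7]  # made-up loss of each candidate "gate"
logits = [0.0, 0.0, 0.0]  # structure parameters for a single layer
lr, batch, baseline = 0.5, 64, 0.0


def softmax(xs):
    shift = max(xs)
    exps = [math.exp(x - shift) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


for step in range(200):
    probs = softmax(logits)
    samples = random.choices(range(3), weights=probs, k=batch)
    grad = [0.0, 0.0, 0.0]
    for k in samples:
        advantage = losses[k] - baseline  # loss relative to previous batch mean
        for j in range(3):
            # (L - baseline) * d ln p_k / d logit_j, averaged over the batch
            grad[j] += advantage * ((1.0 if j == k else 0.0) - probs[j]) / batch
    baseline = sum(losses[k] for k in samples) / batch
    logits = [a - lr * g for a, g in zip(logits, grad)]

best = max(range(3), key=lambda j: logits[j])
print(best, softmax(logits))
```

Because the baseline comes from the previous batch, it does not depend on the current samples and therefore leaves the gradient estimate unbiased while reducing its variance.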
Directly Optimize the Structure Parameters#
Since we are using numerical simulation anyway, we can directly optimize the structure parameters and ignore whether the super circuit is unitary. This approach can be faster and more reliable in some scenarios.
[7]:
def ansatz2(params, structures):
    c = tc.Circuit(n)
    params = K.cast(params, ctype)
    structures = K.softmax(structures, axis=-1)
    structures = K.cast(structures, ctype)
    for i in range(p):
        c.any(
            0,
            1,
            unitary=structures[i, 0] * rx0(params[i, 0])
            + structures[i, 1] * rx1(params[i, 1])
            + structures[i, 2] * ry0(params[i, 2])
            + structures[i, 3] * ry1(params[i, 3])
            + structures[i, 4] * rz0(params[i, 4])
            + structures[i, 5] * rz1(params[i, 5])
            + structures[i, 6] * cnot01()
            + structures[i, 7] * cnot10(),
        )
    s = c.state()
    s /= K.norm(s)
    loss = K.sum(K.abs(target - s))
    return loss


vag2 = K.jit(K.value_and_grad(ansatz2, argnums=(0, 1)))
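The line `s /= K.norm(s)` in `ansatz2` matters because a softmax-weighted sum of distinct unitaries is generally not unitary, so the output state is no longer normalized. The NumPy sketch below illustrates this with just two pool members; the angle, the logits, and the name `rx0_np` are arbitrary choices for illustration.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
I2 = np.eye(2, dtype=complex)
CNOT01 = np.array(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex
)

# two pool members as 4x4 unitaries: an X-type rotation on qubit 0 and cnot01
theta = 0.3  # arbitrary angle for illustration
rx0_np = np.kron(np.cos(theta) * I2 + 1j * np.sin(theta) * X, I2)

# softmax weights over the two candidates (made-up logits)
logits = np.array([0.5, -0.2])
weights = np.exp(logits) / np.exp(logits).sum()

mixed = weights[0] * rx0_np + weights[1] * CNOT01

# a convex mixture of distinct unitaries is generally not unitary ...
is_unitary = np.allclose(mixed.conj().T @ mixed, np.eye(4))

# ... so the output state is renormalized before comparing with the target
state = mixed @ np.array([1, 0, 0, 0], dtype=complex)
norm_before = np.linalg.norm(state)
state = state / norm_before
print(is_unitary, norm_before)
```

As the softmax sharpens toward a one-hot distribution during training, the mixture approaches a single unitary and the renormalization factor approaches one.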
[8]:
verbose = True
epochs = 700
lr = tf.keras.optimizers.schedules.ExponentialDecay(0.05, 200, 0.5)
structure_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(0.04))
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(lr))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
stp = K.implicit_randn(stddev=0.02, shape=[p, 8], dtype=rtype)
for epoch in range(epochs):
    infd, (gnnp, gstp) = vag2(nnp, stp)
    nnp = network_opt.update(gnnp, nnp)
    stp = structure_opt.update(gstp, stp)
    if epoch % 70 == 0 or epoch == epochs - 1:
        print("----------epoch %s-----------" % epoch)
        print(
            "batched average loss: ",
            np.mean(infd),
        )
        if verbose:
            print(
                "strcuture parameter: \n",
                stp.numpy(),
                "\n network parameter: \n",
                nnp.numpy(),
            )
        cand_preset = best_from_structure(stp)
        print("best candidates so far:", [ops_repr[i] for i in cand_preset])
        print(
            "corresponding weights for each gate:",
            [K.numpy(nnp[j, i]) if i < 6 else 0.0 for j, i in enumerate(cand_preset)],
        )
----------epoch 0-----------
batched average loss: 1.3046788213442395
strcuture parameter:
[[ 0.07621179 0.04934165 0.04669995 0.04737751 0.02036102 0.01170415
0.03786593 -0.05644197]
[ 0.01168381 0.0561013 0.02979136 0.03134415 0.03763557 0.03739249
0.03408754 -0.05335854]
[ 0.03540374 0.03219197 0.01680129 0.02014464 0.06939972 0.02393527
0.04619596 -0.01844729]]
network parameter:
[[-0.0584098 0.04281717 0.0642035 0.06008445 0.0357175 0.05512457]
[-0.07067937 0.04410743 0.03608519 0.03465959 0.02446072 0.06917318]
[-0.01337738 0.04776898 0.04278249 0.04917169 0.0495427 0.01059102]]
best candidates so far: ['rx0', 'rx1', 'rz0']
corresponding weights for each gate: [-0.058409803714939854, 0.04410743113093344, 0.04954270315507654]
----------epoch 70-----------
batched average loss: 1.0081966098666586
strcuture parameter:
[[-0.91750096 0.35057522 0.32585577 0.37681816 1.77239369 1.7734987
1.80143958 -0.38591221]
[ 0.30087524 0.28764993 0.36971695 0.36078872 1.79887933 1.47542633
1.79490296 -0.38283427]
[ 0.29950339 0.32101711 -0.07372448 0.34959339 1.83486426 1.78887106
1.81320642 -0.34792317]]
network parameter:
[[ 0.01163284 -0.02749067 -0.00602475 0.46422017 -0.03365732 -0.01443091]
[-0.00057541 -0.02624807 -0.03408587 0.43879875 -0.04520759 -0.00055711]
[ 0.05673025 0.03099979 -0.02736317 0.45331194 -0.02026327 -0.00559595]]
best candidates so far: ['cnot01', 'rz0', 'rz0']
corresponding weights for each gate: [0.0, -0.045207589104267296, -0.02026326781055693]
----------epoch 140-----------
batched average loss: 0.8049806880722175
strcuture parameter:
[[-3.20900567 -2.18126972 1.96173331 0.3704988 0.75310085 2.01979348
2.47701794 -0.37965676]
[-0.78487034 -1.05072503 0.83960507 0.35409074 1.49913186 0.4284363
4.58858068 -0.37664102]
[ 0.72348068 0.29661214 0.82121041 0.34328667 4.57946006 3.79373413
2.24252671 -0.3416766 ]]
network parameter:
[[-5.93268249e-04 -4.03543595e-02 -1.13260135e+00 4.62883177e-01
-3.47753230e-02 -1.57096245e-02]
[-1.38210543e-03 -5.03624409e-02 1.02006945e+00 4.37465879e-01
-4.64645263e-02 -1.16956649e-03]
[-7.80346264e-02 1.90816551e-02 1.09724554e+00 4.51972627e-01
-2.15345680e-02 -6.84665987e-03]]
best candidates so far: ['cnot01', 'cnot01', 'rz0']
corresponding weights for each gate: [0.0, 0.0, -0.02153456797370576]
----------epoch 210-----------
batched average loss: 0.041816948869616476
strcuture parameter:
[[-3.86458991 -2.84058112 2.62327171 0.26992388 0.09167012 1.35827717
1.81549048 -0.56415243]
[-1.16314411 -1.7698344 1.49466411 -0.30614419 0.75064439 -0.31409853
5.25000534 -0.65623059]
[ 1.4704075 0.89799938 2.01474589 2.50046978 4.8946084 4.44647834
1.49549043 -0.52420796]]
network parameter:
[[ 0.00716229 0.0950563 -1.62490102 0.60459966 -0.033863 -0.01472524]
[ 0.41329341 0.02296645 1.58326833 0.57927215 -0.04604745 -0.05234586]
[-0.07409766 0.08796055 -0.2881097 -0.52346262 -0.02053635 -0.00585734]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.624901021386579, 0.0, -0.020536353581361112]
----------epoch 280-----------
batched average loss: 0.04771541732805661
strcuture parameter:
[[-3.86686803 -2.8436989 2.62698311 0.26657346 0.0879849 1.35457084
1.81178177 -0.67868932]
[-1.33926046 -1.75120967 1.4909566 -0.18080598 0.75530859 -0.30731494
5.25370236 -0.67797002]
[ 1.47961761 0.93984054 2.12875762 2.49693907 4.78860574 4.56031727
1.50105437 -3.90111934]]
network parameter:
[[ 0.01376387 0.10062571 -1.62306954 0.60290458 -0.0321182 -0.01292187]
[ 0.41014329 0.10543278 1.58481867 0.66429872 -0.04488787 -0.0548457 ]
[-0.07315826 0.09047629 -0.35272068 -0.52529957 -0.01871782 -0.00404166]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6230695444527814, 0.0, -0.018717818357922293]
----------epoch 350-----------
batched average loss: 0.0484244468649333
strcuture parameter:
[[-3.86889078 -2.84635759 2.63008504 0.26367828 0.08490232 1.35147264
1.80868181 -0.68180282]
[-1.61872015 -1.73606594 1.48784963 -0.16422792 0.75890549 -0.30250011
5.25680704 -0.70078063]
[ 1.48953381 0.96377125 2.13183873 2.49380904 4.79193079 4.5635497
1.5054478 -5.6630325 ]]
network parameter:
[[ 0.02014668 0.10330406 -1.62102814 0.60086296 -0.03016668 -0.01090957]
[ 0.40649692 0.12322429 1.58685415 0.67753112 -0.04356077 -0.05752607]
[-0.0730366 0.09201651 -0.35067966 -0.52733454 -0.01668937 -0.00201604]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6210281393366774, 0.0, -0.01668936553578817]
----------epoch 420-----------
batched average loss: 0.0490371292665724
strcuture parameter:
[[-3.8707214 -2.84868677 2.63274431 0.26116667 0.08225763 1.34881617
1.80602401 -0.68447817]
[-2.17422677 -1.72638998 1.48517837 -0.14718937 0.76130497 -0.30044208
5.25948165 -0.72889795]
[ 1.50031397 0.986817 2.13447811 2.49111713 4.79485557 4.56643982
1.50815833 -6.42399688]]
network parameter:
[[ 2.66201104e-02 1.04232454e-01 -1.61904956e+00 5.98884773e-01
-2.82809595e-02 -8.96136472e-03]
[ 4.01114366e-01 1.42266730e-01 1.58916732e+00 6.88026030e-01
-4.23682835e-02 -6.01395817e-02]
[-7.30273413e-02 9.34775776e-02 -3.48701372e-01 -5.29307121e-01
-1.47240949e-02 -5.37945560e-05]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6190495562985034, 0.0, -0.014724094945736954]
----------epoch 490-----------
batched average loss: 0.04976228840362948
strcuture parameter:
[[-3.87241212 -2.85077089 2.63506255 0.25896325 0.07995022 1.3465
1.80370686 -0.68681685]
[-2.94364254 -1.72476469 1.48284202 -0.13112958 0.76233067 -0.30191621
5.26182547 -0.76243791]
[ 1.5130618 1.0086917 2.13677828 2.4887622 4.79747567 4.56906223
1.50813331 -6.85905145]]
network parameter:
[[ 0.03313986 0.1034762 -1.6172524 0.59708835 -0.02657844 -0.00719578]
[ 0.39208856 0.1630666 1.59170906 0.69656645 -0.04144895 -0.06259256]
[-0.07305937 0.09493765 -0.34690445 -0.53109883 -0.01294077 0.00172629]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.617252401393024, 0.0, -0.012940768995802935]
----------epoch 560-----------
batched average loss: 0.05065420046477945
strcuture parameter:
[[-3.8739842 -2.85266121 2.6371075 0.25701232 0.0779132 1.34445655
1.80166269 -0.68888562]
[-3.87431144 -1.73479083 1.48077359 -0.1180455 0.76187864 -0.30734926
5.26390441 -0.7996279 ]
[ 1.52854329 1.02822466 2.13880714 2.48667702 4.79985267 4.57146409
1.50375206 -7.1518221 ]]
network parameter:
[[ 0.03952625 0.09952938 -1.6156793 0.59551608 -0.02510117 -0.0056556 ]
[ 0.37625454 0.18679017 1.59450355 0.70332975 -0.04084047 -0.0651834 ]
[-0.07341076 0.09584067 -0.3453315 -0.53266713 -0.01138222 0.00328135]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6156792973551577, 0.0, -0.01138222012800648]
----------epoch 630-----------
batched average loss: 0.05169300354431061
strcuture parameter:
[[-3.87544608 -2.85438213 2.63892729 0.25527243 0.0760992 1.34263784
1.79984346 -0.69073143]
[-4.81271904 -1.75890412 1.47892571 -0.1106167 0.75992863 -0.31691451
5.26576503 -0.83861308]
[ 1.54468553 1.04577833 2.14061233 2.48481415 4.80202942 4.57367936
1.49248445 -7.36638329]]
network parameter:
[[ 0.04556425 0.08358514 -1.61433443 0.59417203 -0.02385289 -0.00434529]
[ 0.35393466 0.2231968 1.59768193 0.70856295 -0.04056008 -0.06806482]
[-0.07428592 0.0956737 -0.3439867 -0.53400786 -0.01005294 0.00460687]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.614334431628148, 0.0, -0.010052938889800701]
----------epoch 699-----------
batched average loss: 0.10256012079470897
strcuture parameter:
[[-3.787228 -2.7662122 2.55075055 0.34350642 0.16427612 1.43081451
1.88802021 -0.60259055]
[-5.6073258 -2.03961941 1.56706312 -0.37366319 0.66712936 -0.41905248
5.17764218 -0.96685202]
[ 1.46844149 0.80497766 2.05243854 2.57293907 4.71424946 4.48595762
1.38208153 -7.52953678]]
network parameter:
[[ 4.11010856e-02 9.05029846e-03 -1.60328084e+00 5.83119412e-01
-1.28978327e-02 6.67130794e-03]
[ 3.37676048e-01 5.40067470e-01 1.61148200e+00 7.22459586e-01
-3.43741776e-02 -6.15320200e-02]
[-6.44290250e-02 8.86161369e-02 -3.32933747e-01 -5.45057551e-01
9.83673124e-04 1.56395047e-02]]
best candidates so far: ['ry0', 'cnot01', 'rz0']
corresponding weights for each gate: [-1.6032808418929712, 0.0, 0.0009836731240883082]
Final Fine-tune#
Once a circuit layout has been obtained, we can further fine-tune the circuit weights to bring the objective closer to zero.
[9]:
chosen_structure = K.onehot(np.array([2, 4, 6]), num=8)
chosen_structure = K.reshape(chosen_structure, [1, p, ch])
chosen_structure
[9]:
<tf.Tensor: shape=(1, 3, 8), dtype=float32, numpy=
array([[[0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0.]]], dtype=float32)>
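Each one-hot row above fixes the gate used in one layer. A minimal pure-Python sketch of decoding those rows back into gate names, reusing the `ops_repr` ordering from the gate-pool cell (the list `chosen` is copied by hand from the printed tensor):

```python
ops_repr = ["rx0", "rx1", "ry0", "ry1", "rz0", "rz1", "cnot01", "cnot10"]

# one-hot rows copied from the printed tensor (one row per layer)
chosen = [
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0],
]

# the position of the 1 in each row is the index into the gate pool
layout = [ops_repr[row.index(1)] for row in chosen]
print(layout)
```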
[10]:
network_opt = tc.backend.optimizer(tf.keras.optimizers.Adam(1e-3))
nnp = K.implicit_randn(stddev=0.02, shape=[p, 6], dtype=rtype)
verbose = True
epochs = 600
for epoch in range(epochs):
    infd, gnnp = vag1(nnp, chosen_structure)
    nnp = network_opt.update(gnnp, nnp)
    if epoch % 60 == 0 or epoch == epochs - 1:
        print(epoch, "loss: ", K.numpy(infd[0]))
0 loss: 0.9827084438054802
60 loss: 0.9449745688150044
120 loss: 0.8850948396917335
180 loss: 0.8048454837720991
240 loss: 0.706158632509899
300 loss: 0.5901794549931197
360 loss: 0.45808014651166296
420 loss: 0.3113751664914397
480 loss: 0.1520672098147883
540 loss: 0.0014714944860031312
599 loss: 0.0032763286672428237
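The logged loss starts near 1 and falls to the order of \(10^{-3}\). For the fixed layout ry0, rz0, cnot01, the same loss can be evaluated in plain NumPy under the gate convention of the pool, which shows where the zero of the objective sits. This is an independent sketch with hand-written matrices and illustrative names (`rot0`, `loss`), not the TensorCircuit code path.

```python
import numpy as np

Y = np.array([[0, -1j], [1j, 0]])
Z = np.array([[1, 0], [0, -1]], dtype=complex)
I2 = np.eye(2, dtype=complex)
CNOT01 = np.array(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex
)
target = np.array([1, 0, 0, 1], dtype=complex) / np.sqrt(2)


def rot0(theta, pauli):
    # (cos(theta) I + i sin(theta) P) on qubit 0, identity on qubit 1
    return np.kron(np.cos(theta) * I2 + 1j * np.sin(theta) * pauli, I2)


def loss(theta_y, theta_z):
    # layers in order: ry0, rz0, cnot01, applied to |00>
    state = np.array([1, 0, 0, 0], dtype=complex)
    state = CNOT01 @ rot0(theta_z, Z) @ rot0(theta_y, Y) @ state
    return np.sum(np.abs(target - state))


loss_at_zero = loss(0.0, 0.0)  # roughly where a small random init starts
loss_at_optimum = loss(-np.pi / 4, 0.0)
print(loss_at_zero, loss_at_optimum)
```

With all angles at zero the loss equals 1, consistent with the first logged value for a small random initialization, and it vanishes at an ry0 angle of \(-\pi/4\) with the rz0 angle at zero.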