tensorcircuit.torchnn source code

PyTorch ``nn.Module`` wrappers for quantum functions

from typing import Any, Callable, Sequence, Tuple, Union

import torch

from .cons import backend
from .interfaces import torch_interface
from .utils import is_sequence

# Annotation-only alias for tensor arguments/returns; it is ``Any``, so no
# concrete tensor type is enforced by these annotations.
Tensor = Any

class QuantumNet(torch.nn.Module):
    """PyTorch ``nn.Module`` holding trainable weights and evaluating a quantum function ``f``."""

    def __init__(
        self,
        f: Callable[..., Any],
        weights_shape: Sequence[Tuple[int, ...]],
        initializer: Union[Any, Sequence[Any]] = None,
        use_vmap: bool = True,
        vectorized_argnums: Union[int, Sequence[int]] = 0,
        use_interface: bool = True,
        use_jit: bool = True,
        enable_dlpack: bool = False,
    ):
        """
        PyTorch nn Module wrapper on quantum function ``f``.

        :Example:

        .. code-block:: python

            K = tc.set_backend("tensorflow")

            n = 6
            nlayers = 2
            batch = 2

            def qpred(x, weights):
                c = tc.Circuit(n)
                for i in range(n):
                    c.rx(i, theta=x[i])
                for j in range(nlayers):
                    for i in range(n - 1):
                        c.cnot(i, i + 1)
                    for i in range(n):
                        c.rx(i, theta=weights[2 * j, i])
                        c.ry(i, theta=weights[2 * j + 1, i])
                ypred = K.stack([c.expectation_ps(x=[i]) for i in range(n)])
                ypred = K.real(ypred)
                return ypred

            ql = tc.torchnn.QuantumNet(qpred, weights_shape=[2*nlayers, n])

            ql(torch.ones([batch, n]))


        :param f: Quantum function with tensor in (input and weights) and tensor out.
        :type f: Callable[..., Any]
        :param weights_shape: list of shape tuples, one per weight; the weights are passed
            to ``f`` after the inputs, in this order. A single flat shape
            (e.g. ``[2, 3]``) is treated as one weight.
        :type weights_shape: Sequence[Tuple[int, ...]]
        :param initializer: function that takes a shape tuple and returns a torch tensor.
            Either a single initializer applied to every weight, or a list/tuple with
            one entry per weight. ``None`` (or a ``None`` entry) means ``torch.randn``,
            defaults to None
        :type initializer: Union[Any, Sequence[Any]], optional
        :param use_vmap: whether apply vmap (batch input) on ``f``, defaults to True
        :type use_vmap: bool, optional
        :param vectorized_argnums: which position of input should be batched, need to be
            customized when multiple inputs for the torch model, defaults to be 0.
        :type vectorized_argnums: Union[int, Sequence[int]]
        :param use_interface: whether transform ``f`` with torch interface, defaults to True
        :type use_interface: bool, optional
        :param use_jit: whether jit ``f``, defaults to True
        :type use_jit: bool, optional
        :param enable_dlpack: whether enable dlpack in interfaces, defaults to False
        :type enable_dlpack: bool, optional
        :raises ValueError: if a list/tuple of initializers has fewer entries than there
            are weight shapes.
        """
        super().__init__()
        if use_vmap:
            f = backend.vmap(f, vectorized_argnums=vectorized_argnums)
        if use_interface:
            f = torch_interface(f, jit=use_jit, enable_dlpack=enable_dlpack)
        self.f = f
        self.q_weights = torch.nn.ParameterList()

        # A flat sequence of ints describes a single weight: wrap it so the loop
        # below always iterates over a list of shapes. The emptiness guard lets a
        # weight-free ``f`` be wrapped instead of failing with IndexError.
        if weights_shape and isinstance(weights_shape[0], int):
            weights_shape = [weights_shape]

        if not is_sequence(initializer):
            # One initializer (or None) for all weights. It must be replicated per
            # shape: a one-element list would make ``zip`` stop after the first
            # shape and silently drop every remaining weight.
            initializer = [initializer] * len(weights_shape)
        elif len(initializer) < len(weights_shape):
            raise ValueError(
                "got %d initializers for %d weight shapes"
                % (len(initializer), len(weights_shape))
            )

        for ws, initf in zip(weights_shape, initializer):
            if initf is None:
                initf = torch.randn
            self.q_weights.append(torch.nn.Parameter(initf(ws)))

    def forward(self, *inputs: Tensor) -> Tensor:
        """
        Evaluate the wrapped ``f`` on ``inputs`` followed by the trainable weights.

        :param inputs: positional inputs forwarded to ``f`` before the weights.
        :return: whatever the (possibly vmapped / torch-interfaced) ``f`` returns.
        """
        ypred = self.f(*inputs, *self.q_weights)
        return ypred
# Alternative name for ``QuantumNet``; both names refer to the same class object.
TorchLayer = QuantumNet
class HardwareNet(QuantumNet):
    """
    PyTorch Layer wrapping quantum function with cloud qpu access
    (using :py:mod:`tensorcircuit.cloud` module)
    """

    def __init__(
        self,
        f: Callable[..., Any],
        weights_shape: Sequence[Tuple[int, ...]],
        initializer: Union[Any, Sequence[Any]] = None,
        use_vmap: bool = True,
    ):
        """
        Build the layer around ``f`` without any backend vmap, torch interface or jit.

        :param f: quantum function called as ``f(*inputs, *weights)``.
        :type f: Callable[..., Any]
        :param weights_shape: shape(s) of the trainable weights, as in ``QuantumNet``.
        :type weights_shape: Sequence[Tuple[int, ...]]
        :param initializer: weight initializer(s), as in ``QuantumNet``, defaults to None
        :type initializer: Union[Any, Sequence[Any]], optional
        :param use_vmap: if True, ``forward`` loops over the leading axis of the inputs
            and calls ``f`` once per sample; no backend vmap is applied, defaults to True
        :type use_vmap: bool, optional
        """
        # All parent-side wrapping is switched off: ``f`` is stored unchanged and
        # batching (if requested) is done by hand in ``forward``.
        super().__init__(
            f,
            weights_shape,
            initializer=initializer,
            use_vmap=False,
            use_interface=False,
            use_jit=False,
        )
        self.batch_support = use_vmap

    def forward(self, *inputs: Tensor) -> Tensor:
        """
        Evaluate ``f`` on the inputs and the trainable weights.

        Without batch support, ``f`` is called once on the inputs as given.
        With batch support, every input is indexed at the same position along
        its first axis (the batch size is read from ``inputs[0]``), ``f`` is
        called per sample, and the per-sample results are combined with
        ``torch.stack`` -- so they must be torch tensors of identical shape.
        """
        if not self.batch_support:
            return self.f(*inputs, *self.q_weights)

        per_sample = [
            self.f(*(arg[idx] for arg in inputs), *self.q_weights)
            for idx in range(inputs[0].shape[0])
        ]
        return torch.stack(per_sample)
# Alternative name for ``HardwareNet``; both names refer to the same class object.
TorchHardwareLayer = HardwareNet