Source code for tensorcircuit.backends.backend_factory

"""
Backend registry and factory
"""

from typing import Any, Dict, Text, Union

import tensornetwork as tn

try:  # compatibility with old tensornetwork versions
    # Older releases expose the backend base class as `base_backend.BaseBackend`.
    from tensornetwork.backends import base_backend

    # `tnbackend` is the base class used by `get_backend` to recognise backend objects.
    tnbackend = base_backend.BaseBackend

except ImportError:
    # Fall back to `abstract_backend.AbstractBackend` when `base_backend` cannot be imported
    # (presumably the name used by newer tensornetwork releases — confirm against tn changelog).
    from tensornetwork.backends import abstract_backend

    tnbackend = abstract_backend.AbstractBackend

from .numpy_backend import NumpyBackend
from .jax_backend import JaxBackend
from .tensorflow_backend import TensorFlowBackend
from .pytorch_backend import PyTorchBackend
from .cupy_backend import CuPyBackend

# Annotation alias for backend objects; kept as `Any` (nominally a tnbackend instance).
bk = Any  # tnbackend

# Registry mapping lowercase backend names to their backend classes.
# The pytorch and cupy entries are provided but not intended to be fully maintained.
_BACKENDS = dict(
    numpy=NumpyBackend,
    jax=JaxBackend,
    tensorflow=TensorFlowBackend,
    pytorch=PyTorchBackend,  # no intention to fully maintain this one
    cupy=CuPyBackend,  # no intention to fully maintain this one
)

# Make the CuPy backend reachable through tensornetwork's own backend factory as well.
tn.backends.backend_factory._BACKENDS["cupy"] = CuPyBackend

# Cache of backend instances, filled lazily by `get_backend` and keyed by backend name.
_INSTANTIATED_BACKENDS: Dict[str, bk] = {}


def get_backend(backend: Union[Text, bk]) -> bk:
    """
    Get the `tc.backend` object.

    Backend instances are created lazily on first request and cached, so repeated
    calls with the same name return the same object.

    :param backend: One of "numpy", "tensorflow", "jax", "pytorch" or "cupy"
        (matched case-insensitively), or an existing backend object, which is
        returned unchanged.
    :type backend: Union[Text, tnbackend]
    :raises ValueError: Backend doesn't exist for `backend` argument.
    :return: The `tc.backend` object with all registered universal functions.
    :rtype: backend object
    """
    # Backend objects pass straight through, so callers may supply either a name or an instance.
    if isinstance(backend, tnbackend):
        return backend
    backend = backend.lower()
    if backend not in _BACKENDS:
        raise ValueError("Backend '{}' does not exist".format(backend))
    # Instantiate each backend at most once and reuse the cached instance afterwards.
    if backend not in _INSTANTIATED_BACKENDS:
        _INSTANTIATED_BACKENDS[backend] = _BACKENDS[backend]()
    return _INSTANTIATED_BACKENDS[backend]