Skip to content

Commit

Permalink
feat: finished jitting the easily jittable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Feb 21, 2025
1 parent b7f1e73 commit 5ed0933
Showing 1 changed file with 191 additions and 35 deletions.
226 changes: 191 additions & 35 deletions src/qibojit/custom_operators/quantum_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,6 @@

ENGINE = qinfo.ENGINE # this should be numpy

SIGNATURES = {
# "_vectorization_row": (
# ["c16[:,::1](c16[:,::1], i8)", "c16[:,::1](c16[:,:,::1], i8)"],
# {"parallel": True, "cache": True},
# ),
# "_unvectorization_row": (
# ["c16[:,:,::1](c16[:,::1], i8)", "c16[:,:,::1](c16[:,:,::1], i8)"],
# {"parallel": True, "cache": True},
# ),
# "_random_hermitian": ("c16[:,:](i8)", {"parallel": True, "cache": True})
# "_vectorize_pauli_basis_row": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
# "_vectorize_pauli_basis_column": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
}


def jit_function(signature, function, **kwargs):
if isinstance(function, str):
function = getattr(qinfo, function)
return njit(signature, **kwargs)(function)


class QinfoNumba:
pass


QINFO = QinfoNumba()

for function, signature in SIGNATURES.items():
print(function)
jitted = jit_function(signature[0], function, **signature[1])
globals()[function] = jitted
setattr(QINFO, function, jitted)


@njit("c16[:,:,::1](i8, c16[:,:,:,::1])", parallel=True, cache=True)
def _pauli_basis_inner(
Expand All @@ -49,9 +16,9 @@ def _pauli_basis_inner(
):
dim = 2**nqubits
basis = ENGINE.empty((len(prod), dim, dim), dtype=ENGINE.complex128)
for i in prange(len(prod)): # pylint: disable=not-an-iterable
for i in prange(len(prod)):
elem = prod[i][0]
for j in prange(1, len(prod[i])): # pylint: disable=not-an-iterable
for j in prange(1, len(prod[i])):
elem = ENGINE.kron(elem, prod[i][j])
basis[i] = elem
return basis

Check warning on line 24 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L17-L24

Added lines #L17 - L24 were not covered by tests
Expand Down Expand Up @@ -131,6 +98,8 @@ def _vectorization_column(state, dim):


# dynamic tuple creation is not possible in numba
# this might be jittable if we passed the shape
# dim = (2,) * 2 * nqubits as inputs
@njit
def _vectorization_system(state, dim=0):
nqubits = int(ENGINE.log2(state.shape[-1]))
Expand Down Expand Up @@ -163,6 +132,19 @@ def _unvectorization_column(state, dim):
return state.T

Check warning on line 132 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L132

Added line #L132 was not covered by tests


@njit("c16[:,:](c16[:,:], i8, i8)", parallel=True, cache=True)
def _reshuffling(super_op, ax1: int, ax2: int):
dim = int(ENGINE.sqrt(super_op.shape[0]))
super_op = ENGINE.reshape(ENGINE.ascontiguousarray(super_op), (dim, dim, dim, dim))
axes = ENGINE.arange(super_op.ndim)
tmp = axes[ax1]
axes[ax1] = axes[ax2]
axes[ax2] = tmp

Check warning on line 142 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L137-L142

Added lines #L137 - L142 were not covered by tests
# axes[[ax1, ax2]] = axes[[ax2, ax1]]
super_op = ENGINE.transpose(super_op, to_fixed_tuple(axes, 4))
return ENGINE.reshape(ENGINE.ascontiguousarray(super_op), (dim**2, dim**2))

Check warning on line 145 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L144-L145

Added lines #L144 - L145 were not covered by tests


@njit(
[
nbt.complex128[:](
Expand Down Expand Up @@ -290,6 +272,30 @@ def _pauli_to_comp_basis_sparse_row(
return ENGINE.ascontiguousarray(unitary).reshape(unitary.shape[0], -1), nonzero[1]

Check warning on line 272 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L269-L272

Added lines #L269 - L272 were not covered by tests


@njit(
nbt.Tuple((nbt.complex128[:, ::1], nbt.int64[:]))(
nbt.int64,
nbt.complex128[:, ::1],
nbt.complex128[:, ::1],
nbt.complex128[:, ::1],
nbt.complex128[:, ::1],
nbt.float64,
),
parallel=True,
cache=True,
)
def _pauli_to_comp_basis_sparse_column(
nqubits: int, pauli_0, pauli_1, pauli_2, pauli_3, normalization: float = 1.0
):
unitary = _vectorize_pauli_basis_column(

Check warning on line 290 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L290

Added line #L290 was not covered by tests
nqubits, pauli_0, pauli_1, pauli_2, pauli_3, normalization
)
unitary = unitary.T
nonzero = ENGINE.nonzero(unitary)
unitary = _array_at_2d_indices(unitary, nonzero)
return ENGINE.ascontiguousarray(unitary).reshape(unitary.shape[0], -1), nonzero[1]

Check warning on line 296 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L293-L296

Added lines #L293 - L296 were not covered by tests


@njit(
nbt.Tuple((nbt.complex128[:, :], nbt.complex128[:, :], nbt.float64[:, :, ::1]))(
nbt.complex128[:, :]
Expand Down Expand Up @@ -445,6 +451,9 @@ def expm(A):
"""


# if we can implement the expm in pure numba
# we will be able to completely jit random unitary
# and the other functions that depend on it
def _random_unitary(dims: int):
H = _random_hermitian(dims)
return expm(-1.0j * H / 2)

Check warning on line 459 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L458-L459

Added lines #L458 - L459 were not covered by tests
Expand All @@ -461,6 +470,9 @@ def _random_density_matrix_bures_inner(
return state / ENGINE.trace(state)

Check warning on line 470 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L466-L470

Added lines #L466 - L470 were not covered by tests


# not entirely jittable because depends on random unitary


def _random_density_matrix_bures(dims: int, rank: int, mean: float, stddev: float):
unitary = _random_unitary(dims)
return _random_density_matrix_bures_inner(unitary, dims, rank, mean, stddev)

Check warning on line 478 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L477-L478

Added lines #L477 - L478 were not covered by tests
Expand Down Expand Up @@ -612,6 +624,36 @@ def _super_op_from_bcsz_measure_column(dims: int, rank: int):
return operator @ super_op @ operator

Check warning on line 624 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L622-L624

Added lines #L622 - L624 were not covered by tests


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _super_op_from_haar_measure_row(dims: int):
super_op = _random_unitary_haar(dims)
super_op = _vectorization_row(super_op, dims)
return ENGINE.outer(super_op, ENGINE.conj(super_op))

Check warning on line 631 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L629-L631

Added lines #L629 - L631 were not covered by tests


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _super_op_from_haar_measure_column(dims: int):
super_op = _random_unitary_haar(dims)
super_op = _vectorization_column(super_op, dims)
return ENGINE.outer(super_op, ENGINE.conj(super_op))

Check warning on line 638 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L636-L638

Added lines #L636 - L638 were not covered by tests


# these two can't be jitted (at least globally) because of
# random unitary


def _super_op_from_hermitian_measure_row(dims: int):
super_op = _random_unitary(dims)
super_op = _vectorization_row(super_op, dims)
return ENGINE.outer(super_op, ENGINE.conj(super_op))

Check warning on line 648 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L646-L648

Added lines #L646 - L648 were not covered by tests


def _super_op_from_hermitian_measure_column(dims: int):
super_op = _random_unitary(dims)
super_op = _vectorization_column(super_op, dims)
return ENGINE.outer(super_op, ENGINE.conj(super_op))

Check warning on line 654 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L652-L654

Added lines #L652 - L654 were not covered by tests


@njit("c16[:,:](c16[:,:,::1], c16[:], i8)", parallel=True, cache=True)
def _kraus_to_stinespring(kraus_ops, initial_state_env, dim_env: int):
alphas = ENGINE.zeros((dim_env, dim_env, dim_env), dtype=initial_state_env.dtype)
Expand Down Expand Up @@ -645,18 +687,97 @@ def _stinespring_to_kraus(stinespring, initial_state_env, dim: int, dim_env: int
return kraus.reshape(dim, dim_env, dim_env)

Check warning on line 687 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L673-L687

Added lines #L673 - L687 were not covered by tests


@njit("c16[:,:](c16[:,:])", parallel=True, cache=True)
def _to_choi_row(channel):
channel = _vectorization_row(channel, channel.shape[-1])
return ENGINE.outer(channel, ENGINE.conj(channel))

Check warning on line 693 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L692-L693

Added lines #L692 - L693 were not covered by tests


@njit("c16[:,:](c16[:,:])", parallel=True, cache=True)
def _to_choi_column(channel):
channel = _vectorization_column(channel, channel.shape[-1])
return ENGINE.outer(channel, ENGINE.conj(channel))

Check warning on line 699 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L698-L699

Added lines #L698 - L699 were not covered by tests


@njit("c16[:,:](c16[:,:])", parallel=True, cache=True)
def _to_liouville_row(channel):
channel = _to_choi_row(channel)
return _reshuffling(channel, 1, 2)

Check warning on line 705 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L704-L705

Added lines #L704 - L705 were not covered by tests


@njit("c16[:,:](c16[:,:])", parallel=True, cache=True)
def _to_liouville_column(channel):
channel = _to_choi_column(channel)
return _reshuffling(channel, 0, 3)

Check warning on line 711 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L710-L711

Added lines #L710 - L711 were not covered by tests


@njit(
nbt.Tuple((nbt.complex128[:, :, :], nbt.float64[:]))(
nbt.float64[:], nbt.complex128[:, :], nbt.float64
),
parallel=True,
cache=True,
)
def _choi_to_kraus_cp_row(eigenvalues, eigenvectors, precision: float):
eigv_gt_tol = ENGINE.abs(eigenvalues) > precision
coefficients = ENGINE.sqrt(eigenvalues[eigv_gt_tol])
eigenvectors = eigenvectors[eigv_gt_tol]
dim = int(ENGINE.sqrt(eigenvectors.shape[-1]))
kraus_ops = coefficients.reshape(-1, 1, 1) * _unvectorization_row(eigenvectors, dim)
return kraus_ops, coefficients

Check warning on line 727 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L722-L727

Added lines #L722 - L727 were not covered by tests


@njit(
nbt.Tuple((nbt.complex128[:, :, :], nbt.float64[:]))(
nbt.float64[:], nbt.complex128[:, :], nbt.float64
),
parallel=True,
cache=True,
)
def _choi_to_kraus_cp_column(eigenvalues, eigenvectors, precision: float):
eigv_gt_tol = ENGINE.abs(eigenvalues) > precision
coefficients = ENGINE.sqrt(eigenvalues[eigv_gt_tol])
eigenvectors = eigenvectors[eigv_gt_tol]
dim = int(ENGINE.sqrt(eigenvectors.shape[-1]))
kraus_ops = coefficients.reshape(-1, 1, 1) * _unvectorization_column(

Check warning on line 742 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L738-L742

Added lines #L738 - L742 were not covered by tests
eigenvectors, dim
)
return kraus_ops, coefficients

Check warning on line 745 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L745

Added line #L745 was not covered by tests


@njit("c16[:,:](c16[:,:,:])", parallel=True, cache=True)
def _kraus_to_choi_row(kraus_ops):
kraus_ops = _vectorization_row(kraus_ops, kraus_ops.shape[-1])
return kraus_ops.T @ ENGINE.conj(kraus_ops)

Check warning on line 751 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L750-L751

Added lines #L750 - L751 were not covered by tests


@njit("c16[:,:](c16[:,:,:])", parallel=True, cache=True)
def _kraus_to_choi_column(kraus_ops):
kraus_ops = _vectorization_column(kraus_ops, kraus_ops.shape[-1])
return kraus_ops.T @ ENGINE.conj(kraus_ops)

Check warning on line 757 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L756-L757

Added lines #L756 - L757 were not covered by tests


class QinfoNumba:
pass


QINFO = QinfoNumba()


for function in (
_pauli_basis,
_vectorization_row,
_vectorization_column,
_unvectorization_row,
_unvectorization_column,
_reshuffling,
_post_sparse_pauli_basis_vectorization,
_vectorize_pauli_basis_row,
_vectorize_pauli_basis_column,
_vectorize_sparse_pauli_basis_row,
_vectorize_sparse_pauli_basis_column,
_pauli_to_comp_basis_sparse_row,
_pauli_to_comp_basis_sparse_column,
_choi_to_kraus_row,
_choi_to_kraus_column,
_random_statevector,
Expand All @@ -672,7 +793,42 @@ def _stinespring_to_kraus(stinespring, initial_state_env, dim: int, dim_env: int
_gamma_delta_matrices,
_super_op_from_bcsz_measure_row,
_super_op_from_bcsz_measure_column,
_super_op_from_haar_measure_row,
_super_op_from_haar_measure_column,
_super_op_from_hermitian_measure_row,
_super_op_from_hermitian_measure_column,
_kraus_to_stinespring,
_stinespring_to_kraus,
_to_choi_row,
_to_choi_column,
_to_liouville_row,
_to_liouville_column,
_choi_to_kraus_cp_row,
_choi_to_kraus_cp_column,
_kraus_to_choi_row,
_kraus_to_choi_column,
):
setattr(QINFO, function.__name__, function)


# it would be quite cool and spare us a lot of code repetition if
# we could make a recursive approach like the one below working

SIGNATURES = {
# "_random_hermitian": ("c16[:,:](i8)", {"parallel": True, "cache": True})
# "_vectorize_pauli_basis_row": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
# "_vectorize_pauli_basis_column": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
}


def jit_function(signature, function, **kwargs):
if isinstance(function, str):
function = getattr(qinfo, function)
return njit(signature, **kwargs)(function)

Check warning on line 827 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L825-L827

Added lines #L825 - L827 were not covered by tests


for function, signature in SIGNATURES.items():
print(function)
jitted = jit_function(signature[0], function, **signature[1])
globals()[function] = jitted
setattr(QINFO, function, jitted)

Check warning on line 834 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L831-L834

Added lines #L831 - L834 were not covered by tests

0 comments on commit 5ed0933

Please sign in to comment.