Skip to content

Commit

Permalink
feat: jitted more random ensamble functions
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Feb 20, 2025
1 parent 8993e63 commit 1c3afdc
Showing 1 changed file with 288 additions and 16 deletions.
304 changes: 288 additions & 16 deletions src/qibojit/custom_operators/quantum_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
import numba.types as nbt
import numpy as np
import qibo.quantum_info.quantum_info as qinfo
from numba import njit, prange
from numba import njit, prange, void
from numba.np.unsafe.ndarray import to_fixed_tuple
from scipy.linalg import expm

ENGINE = qinfo.ENGINE

SIGNATURES = {
"_vectorization_row": (
["c16[:,::1](c16[:,::1], i8)", "c16[:,::1](c16[:,:,::1], i8)"],
{"parallel": True, "cache": True},
),
"_unvectorization_row": (
["c16[:,:,::1](c16[:,::1], i8)", "c16[:,:,::1](c16[:,:,::1], i8)"],
{"parallel": True, "cache": True},
),
# "_vectorization_row": (
# ["c16[:,::1](c16[:,::1], i8)", "c16[:,::1](c16[:,:,::1], i8)"],
# {"parallel": True, "cache": True},
# ),
# "_unvectorization_row": (
# ["c16[:,:,::1](c16[:,::1], i8)", "c16[:,:,::1](c16[:,:,::1], i8)"],
# {"parallel": True, "cache": True},
# ),
# "_random_hermitian": ("c16[:,:](i8)", {"parallel": True, "cache": True})
# "_vectorize_pauli_basis_row": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
# "_vectorize_pauli_basis_column": ("c16[:,::1](i8, c16[:,::1], c16[:,::1], c16[:,::1], c16[:,::1], f8)", {"parallel": True, "cache": True}),
}
Expand Down Expand Up @@ -114,6 +116,18 @@ def numba_transpose(array, axes):
return array

Check warning on line 116 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L116

Added line #L116 was not covered by tests


@njit(
["c16[:,::1](c16[:,:], i8)", "c16[:,::1](c16[:,:,:], i8)"],
parallel=True,
cache=True,
)
def _vectorization_row(state, dim: int):
return ENGINE.reshape(ENGINE.ascontiguousarray(state), (-1, dim**2))

Check warning on line 125 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L125

Added line #L125 was not covered by tests


setattr(QINFO, "_vectorization_row", _vectorization_row)


@njit(["c16[:,::1](c16[:,:], i8)", "c16[:,::1](c16[:,:,:], i8)"], cache=True)
def _vectorization_column(state, dim):
indices = ENGINE.arange(state.ndim)
Expand Down Expand Up @@ -141,13 +155,26 @@ def _vectorization_system(state, dim=0):
# setattr(QINFO, "_vectorization_system", _vectorization_system)


@njit(
["c16[:,:,::1](c16[:,:], i8)", "c16[:,:,::1](c16[:,:,:], i8)"],
parallel=True,
cache=True,
)
def _unvectorization_row(state, dim: int):
return ENGINE.reshape(ENGINE.ascontiguousarray(state), (state.shape[0], dim, dim))

Check warning on line 164 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L164

Added line #L164 was not covered by tests


setattr(QINFO, "_unvectorization_row", _unvectorization_row)


@njit(["c16[:,:,:](c16[:,:], i8)", "c16[:,:,:](c16[:,:,:], i8)"], cache=True)
def _unvectorization_column(state, dim):
axes = ENGINE.arange(state.ndim)[::-1]
# axes = ENGINE.arange(state.ndim)[::-1]
last_dim = state.shape[0]
state = numba_transpose(state, axes)
state = state.T # numba_transpose(state, axes)
state = ENGINE.ascontiguousarray(state).reshape(dim, dim, last_dim)

Check warning on line 175 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L173-L175

Added lines #L173 - L175 were not covered by tests
return numba_transpose(state, ENGINE.array([2, 1, 0], dtype=ENGINE.int64))
# return numba_transpose(state, ENGINE.array([2, 1, 0], dtype=ENGINE.int64))
return state.T

Check warning on line 177 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L177

Added line #L177 was not covered by tests


setattr(QINFO, "_unvectorization_column", _unvectorization_column)
Expand All @@ -159,6 +186,7 @@ def _unvectorization_column(state, dim):
nbt.complex128[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:]))
),
nbt.float64[:](nbt.float64[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:]))),
nbt.int64[:](nbt.int64[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:]))),
],
parallel=True,
cache=True,
Expand Down Expand Up @@ -293,7 +321,8 @@ def _pauli_to_comp_basis_sparse_row(
unitary = _vectorize_pauli_basis_row(

Check warning on line 321 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L321

Added line #L321 was not covered by tests
nqubits, pauli_0, pauli_1, pauli_2, pauli_3, normalization
)
unitary = numba_transpose(unitary, ENGINE.arange(unitary.ndim)[::-1])
# unitary = numba_transpose(unitary, ENGINE.arange(unitary.ndim)[::-1])
unitary = unitary.T
nonzero = ENGINE.nonzero(unitary)
unitary = _array_at_2d_indices(unitary, nonzero)
return ENGINE.ascontiguousarray(unitary).reshape(unitary.shape[0], -1), nonzero[1]

Check warning on line 328 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L325-L328

Added lines #L325 - L328 were not covered by tests
Expand All @@ -312,7 +341,8 @@ def _pauli_to_comp_basis_sparse_row(
def _choi_to_kraus_preamble(choi_super_op):
U, coefficients, V = ENGINE.linalg.svd(choi_super_op)

Check warning on line 342 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L342

Added line #L342 was not covered by tests
# U = np.ascontiguousarray(U)
U = numba_transpose(U, ENGINE.arange(U.ndim)[::-1])
# U = numba_transpose(U, ENGINE.arange(U.ndim)[::-1])
U = U.T
coefficients = ENGINE.sqrt(coefficients)
V = ENGINE.conj(V)
coefficients = coefficients.reshape(U.shape[0], 1, 1)

Check warning on line 348 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L345-L348

Added lines #L345 - L348 were not covered by tests
Expand All @@ -337,8 +367,8 @@ def _kraus_operators(kraus_left, kraus_right):
def _choi_to_kraus_row(choi_super_op):
U, V, coefficients = _choi_to_kraus_preamble(choi_super_op)
dim = int(np.sqrt(U.shape[-1]))
kraus_left = coefficients * _unvectorization_row(ENGINE.ascontiguousarray(U), dim)
kraus_right = coefficients * _unvectorization_row(ENGINE.ascontiguousarray(V), dim)
kraus_left = coefficients * _unvectorization_row(U, dim)
kraus_right = coefficients * _unvectorization_row(V, dim)
kraus_ops = _kraus_operators(kraus_left, kraus_right)
return kraus_ops, coefficients

Check warning on line 373 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L368-L373

Added lines #L368 - L373 were not covered by tests

Expand All @@ -364,6 +394,16 @@ def _choi_to_kraus_column(choi_super_op):
setattr(QINFO, "_choi_to_kraus_column", _choi_to_kraus_column)


@njit("c16[:](i8)", parallel=True, cache=True)
def _random_statevector(dims: int):
state = ENGINE.random.standard_normal(dims)
state = state + 1.0j * ENGINE.random.standard_normal(dims)
return state / ENGINE.linalg.norm(state)


setattr(QINFO, "_random_statevector", _random_statevector)


@njit("c16[:,:](i8, i8, f8, f8)", parallel=True, cache=True)
def _random_gaussian_matrix(dims: int, rank: int, mean: float, stddev: float):
matrix = ENGINE.empty((dims, rank), dtype=ENGINE.complex128)
Expand All @@ -378,6 +418,238 @@ def _random_gaussian_matrix(dims: int, rank: int, mean: float, stddev: float):
setattr(QINFO, "_random_gaussian_matrix", _random_gaussian_matrix)


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _random_density_matrix_pure(dims: int):
state = _random_statevector(dims)
return ENGINE.outer(state, ENGINE.conj(state).T)

Check warning on line 424 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L423-L424

Added lines #L423 - L424 were not covered by tests


setattr(QINFO, "_random_density_matrix_pure", _random_density_matrix_pure)


@njit("c16[:,:](i8, i8, f8, f8)", parallel=True, cache=True)
def _random_density_matrix_hs_ginibre(dims: int, rank: int, mean: float, stddev: float):
state = _random_gaussian_matrix(dims, rank, mean, stddev)
state = state @ ENGINE.transpose(ENGINE.conj(state), (1, 0))
return state / ENGINE.trace(state)


setattr(QINFO, "_random_density_matrix_hs_ginibre", _random_density_matrix_hs_ginibre)


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _random_hermitian(dims: int):
matrix = _random_gaussian_matrix(dims, dims, 0.0, 1.0)
return (matrix + ENGINE.conj(matrix).T) / 2

Check warning on line 443 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L442-L443

Added lines #L442 - L443 were not covered by tests


setattr(QINFO, "_random_hermitian", _random_hermitian)


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _random_hermitian_semidefinite(dims: int):
matrix = _random_gaussian_matrix(dims, dims, 0.0, 1.0)
return ENGINE.conj(matrix).T @ matrix

Check warning on line 452 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L451-L452

Added lines #L451 - L452 were not covered by tests


setattr(QINFO, "_random_hermitian_semidefinite", _random_hermitian_semidefinite)


@njit("c16[:,:](i8)", parallel=True, cache=True)
def _random_unitary_haar(dims: int):
matrix = _random_gaussian_matrix(dims, dims, 0.0, 1.0)
Q, R = ENGINE.linalg.qr(matrix)
D = ENGINE.diag(R)
D = D / ENGINE.abs(D)
R = ENGINE.diag(D)
return ENGINE.ascontiguousarray(Q) @ R

Check warning on line 465 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L460-L465

Added lines #L460 - L465 were not covered by tests


setattr(QINFO, "_random_unitary_haar", _random_unitary_haar)

"""
# double check whether this is correct
#@njit
def expm(A):
'''Compute expm(A) using the Padé approximant and scaling/squaring.'''
# Constants for Padé approximant
pade_coeffs = np.array([
64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
1187353796428800.0, 129060195264000.0, 10559470521600.0,
670442572800.0, 33522128640.0, 1323241920.0, 40840800.0,
960960.0, 16380.0, 182.0, 1.0
])
n = A.shape[0]
A_norm = np.max(np.sum(np.abs(A), axis=1)) # Compute norm estimate
# Scaling step
s = max(0, int(np.log2(A_norm)) - 4)
A_scaled = A / (2 ** s)
# Compute Padé approximant
X = A_scaled @ A_scaled
U = ENGINE.eye(n, dtype=A.dtype) * pade_coeffs[1]
V = ENGINE.eye(n, dtype=A.dtype) * pade_coeffs[0]
for i in range(2, len(pade_coeffs)):
U = X @ U + pade_coeffs[i] * ENGINE.eye(n, dtype=A.dtype)
V = X @ V + pade_coeffs[i - 1] * ENGINE.eye(n, dtype=A.dtype)
U = A_scaled @ U
P = V + U
Q = V - U
breakpoint()
# Solve (I - U)⁻¹ * (I + U)
F = ENGINE.linalg.solve(Q, P)
# Squaring step
for _ in range(s):
F = F @ F
return F
"""


def _random_unitary(dims: int):
H = _random_hermitian(dims)
return expm(-1.0j * H / 2)

Check warning on line 517 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L516-L517

Added lines #L516 - L517 were not covered by tests


setattr(QINFO, "_random_unitary", _random_unitary)


@njit("c16[:,:](c16[:,:], i8, i8, f8, f8)", parallel=True, cache=True)
def _random_density_matrix_bures_inner(
unitary, dims: int, rank: int, mean: float, stddev: float
):
state = ENGINE.eye(dims, dtype=unitary.dtype)
state += unitary
state = state @ _random_gaussian_matrix(dims, rank, mean, stddev)
state = state @ ENGINE.transpose(ENGINE.conj(state), (1, 0))
return state / ENGINE.trace(state)

Check warning on line 531 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L527-L531

Added lines #L527 - L531 were not covered by tests


def _random_density_matrix_bures(dims: int, rank: int, mean: float, stddev: float):
unitary = _random_unitary(dims)
return _random_density_matrix_bures_inner(unitary, dims, rank, mean, stddev)

Check warning on line 536 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L535-L536

Added lines #L535 - L536 were not covered by tests


setattr(QINFO, "_random_density_matrix_bures", _random_density_matrix_bures)


@njit(nbt.Tuple((nbt.int64[:], nbt.int64[:]))(nbt.int64), parallel=True, cache=True)
def _sample_from_quantum_mallows_distribution(nqubits: int):
exponents = ENGINE.arange(nqubits, 0, -1, dtype=ENGINE.int64)
powers = 4**exponents
powers[powers == 0] = ENGINE.iinfo(ENGINE.int64).max
r = ENGINE.random.uniform(0, 1, size=nqubits)
indexes = (-1) * ENGINE.ceil(ENGINE.log2(r + (1 - r) / powers)).astype(ENGINE.int64)
idx_le_exp = indexes < exponents
hadamards = idx_le_exp.astype(ENGINE.int64)
idx_gt_exp = idx_le_exp ^ True
indexes[idx_gt_exp] = 2 * exponents[idx_gt_exp] - indexes[idx_gt_exp] - 1
mute_index = list(range(nqubits))
permutations = ENGINE.zeros(nqubits, dtype=ENGINE.int64)
for l, index in enumerate(indexes):
permutations[l] = mute_index[index]
del mute_index[index]
return hadamards, permutations

Check warning on line 558 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L544-L558

Added lines #L544 - L558 were not covered by tests


setattr(
QINFO,
"_sample_from_quantum_mallows_distribution",
_sample_from_quantum_mallows_distribution,
)


@njit(
[
void(
nbt.complex128[:, :],
nbt.Tuple((nbt.int64[:], nbt.int64[:])),
nbt.complex128[:],
),
void(
nbt.float64[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:])), nbt.float64[:]
),
void(nbt.int64[:, :], nbt.Tuple((nbt.int64[:], nbt.int64[:])), nbt.int64[:]),
],
parallel=True,
cache=True,
)
def _set_array_at_2d_indices(array, indices, values):
for i in prange(len(indices[0])):
array[indices[0][i], indices[1][i]] = values[i]

Check warning on line 585 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L584-L585

Added lines #L584 - L585 were not covered by tests


@njit(
nbt.Tuple((nbt.int64[:, :], nbt.int64[:, :], nbt.int64[:, :], nbt.int64[:, :]))(
nbt.int64, nbt.int64[:], nbt.int64[:]
),
parallel=True,
cache=True,
)
def _gamma_delta_matrices(nqubits: int, hadamards, permutations):
delta_matrix = ENGINE.eye(nqubits, dtype=ENGINE.int64)
delta_matrix_prime = ENGINE.copy(delta_matrix)

Check warning on line 597 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L596-L597

Added lines #L596 - L597 were not covered by tests

gamma_matrix_prime = ENGINE.random.randint(0, 2, size=nqubits)
gamma_matrix_prime = ENGINE.diag(gamma_matrix_prime)

Check warning on line 600 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L599-L600

Added lines #L599 - L600 were not covered by tests

gamma_matrix = ENGINE.random.randint(0, 2, size=nqubits)
gamma_matrix = hadamards * gamma_matrix
gamma_matrix = ENGINE.diag(gamma_matrix)

Check warning on line 604 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L602-L604

Added lines #L602 - L604 were not covered by tests

tril_indices = ENGINE.tril_indices(nqubits, k=-1)
_set_array_at_2d_indices(

Check warning on line 607 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L606-L607

Added lines #L606 - L607 were not covered by tests
delta_matrix_prime,
tril_indices,
ENGINE.random.randint(0, 2, size=len(tril_indices[0])),
)

_set_array_at_2d_indices(

Check warning on line 613 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L613

Added line #L613 was not covered by tests
gamma_matrix_prime,
tril_indices,
ENGINE.random.randint(0, 2, size=len(tril_indices[0])),
)

triu_indices = ENGINE.triu_indices(nqubits, k=1)
_set_array_at_2d_indices(

Check warning on line 620 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L619-L620

Added lines #L619 - L620 were not covered by tests
gamma_matrix_prime,
triu_indices,
_array_at_2d_indices(gamma_matrix_prime, tril_indices),
)

p_col_gt_row = permutations[triu_indices[1]] > permutations[triu_indices[0]]
p_col_neq_row = permutations[triu_indices[1]] != permutations[triu_indices[0]]
p_col_le_row = p_col_gt_row ^ True
h_row_eq_0 = hadamards[triu_indices[0]] == 0
h_col_eq_0 = hadamards[triu_indices[1]] == 0

Check warning on line 630 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L626-L630

Added lines #L626 - L630 were not covered by tests

idx = (h_row_eq_0 * h_col_eq_0 ^ True) * p_col_neq_row
elements = ENGINE.random.randint(0, 2, size=len(idx.nonzero()[0]))
_set_array_at_2d_indices(

Check warning on line 634 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L632-L634

Added lines #L632 - L634 were not covered by tests
gamma_matrix, (triu_indices[0][idx], triu_indices[1][idx]), elements
)
_set_array_at_2d_indices(

Check warning on line 637 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L637

Added line #L637 was not covered by tests
gamma_matrix, (triu_indices[1][idx], triu_indices[0][idx]), elements
)

idx = p_col_gt_row | (p_col_le_row * h_row_eq_0 * h_col_eq_0)
elements = ENGINE.random.randint(0, 2, size=len(idx.nonzero()[0]))
_set_array_at_2d_indices(

Check warning on line 643 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L641-L643

Added lines #L641 - L643 were not covered by tests
delta_matrix, (triu_indices[1][idx], triu_indices[0][idx]), elements
)

return gamma_matrix, gamma_matrix_prime, delta_matrix, delta_matrix_prime

Check warning on line 647 in src/qibojit/custom_operators/quantum_info.py

View check run for this annotation

Codecov / codecov/patch

src/qibojit/custom_operators/quantum_info.py#L647

Added line #L647 was not covered by tests


setattr(QINFO, "_gamma_delta_matrices", _gamma_delta_matrices)


"""
@njit(
[
Expand Down

0 comments on commit 1c3afdc

Please sign in to comment.