Skip to content

Commit

Permalink
feat: jitting some super op transf
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Feb 19, 2025
1 parent 9d1a7fe commit 6c7fecc
Showing 1 changed file with 90 additions and 5 deletions.
95 changes: 90 additions & 5 deletions src/qibojit/custom_operators/quantum_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,15 @@ def _vectorization_system(state, dim=0):
setattr(QINFO, "_vectorization_column", _vectorization_column)
# setattr(QINFO, "_vectorization_system", _vectorization_system)

"""
@njit

@njit(["c16[:,:,::1](c16[:,::1], i8)", "c16[:,:,::1](c16[:,:,::1], i8)"], cache=True)
def _unvectorization_column(state, dim):
state = ENGINE.reshape(state, (state.shape[0], dim, dim))
return np.asfortranarray(state)
axes = ENGINE.arange(state.ndim)[::-1]
state = numba_transpose(state, axes).reshape(dim, dim, state.shape[0])
return numba_transpose(state, ENGINE.array([2, 1, 0], dtype=ENGINE.int64))


setattr(QINFO, "_unvectorization_column", _unvectorization_column)
"""


@njit(
Expand Down Expand Up @@ -299,3 +300,87 @@ def _pauli_to_comp_basis_sparse_row(


setattr(QINFO, "_pauli_to_comp_basis_sparse_row", _pauli_to_comp_basis_sparse_row)


@njit(
nbt.Tuple((nbt.complex128[:, ::1], nbt.complex128[:, ::1], nbt.float64[:, :, ::1]))(
nbt.complex128[:, ::1]
),
parallel=True,
cache=True,
)
def _choi_to_kraus_preamble(choi_super_op):
U, coefficients, V = ENGINE.linalg.svd(choi_super_op)
U = np.ascontiguousarray(U)
U = numba_transpose(U, ENGINE.arange(U.ndim)[::-1])
coefficients = ENGINE.sqrt(coefficients)
V = ENGINE.conj(V)
coefficients = coefficients.reshape(U.shape[0], 1, 1)
V = np.ascontiguousarray(V)
return U, V, coefficients


@njit("c16[:,:,:,:](c16[:,:,:], c16[:,:,:])", parallel=True, cache=True)
def _kraus_operators(kraus_left, kraus_right):
kraus_ops = ENGINE.empty((2,) + kraus_left.shape, dtype=kraus_left.dtype)
kraus_ops[0] = kraus_left
kraus_ops[1] = kraus_right
return kraus_ops


@njit(
nbt.Tuple((nbt.complex128[:, :, :, :], nbt.float64[:, :, ::1]))(
nbt.complex128[:, ::1]
),
cache=True,
)
def _choi_to_kraus_row(choi_super_op):
U, V, coefficients = _choi_to_kraus_preamble(choi_super_op)
dim = int(np.sqrt(U.shape[-1]))
kraus_left = coefficients * _unvectorization_row(U, dim)
kraus_right = coefficients * _unvectorization_row(V, dim)
kraus_ops = _kraus_operators(kraus_left, kraus_right)
return kraus_ops, coefficients


setattr(QINFO, "_choi_to_kraus_row", _choi_to_kraus_row)

# TODO: choi to kraus column

"""
@njit(
[
nbt.complex128[::1](
nbt.complex128[:, ::1], nbt.Tuple((nbt.int64[::1], nbt.int64[::1]))
),
nbt.float64[::1](
nbt.float64[:, ::1], nbt.Tuple((nbt.int64[::1], nbt.int64[::1]))
),
],
parallel=True,
cache=True,
)
def _set_array_at_2d_indices(array, indices, values):
empty = ENGINE.empty(indices[0].shape, dtype=array.dtype)
for i in prange(len(indices[0])):
empty[i] = array[indices[0][i], indices[1][i]]
return empty
"""

"""
def _kraus_to_stinespring(
kraus_ops, initial_state_env, dim_env: int
):
alphas = ENGINE.zeros((dim_env, dim_env, dim_env), dtype=complex)
#idx = ENGINE.arange(dim_env)
alphas[range(dim_env), range(dim_env)] = initial_state_env
# batched kron product
prod = 0.
for i in prange(len(kraus_ops)):
prod += ENGINE.kron(kraus_ops[i], alphas[i])
return prod.reshape(
2 * (kraus_ops.shape[1] * alphas.shape[1],)
)
"""

setattr(QINFO, "_kraus_to_stinespring", _kraus_to_stinespring)

0 comments on commit 6c7fecc

Please sign in to comment.