Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 198 additions & 63 deletions deeptrack/backend/mie.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@

from __future__ import annotations

import array_api_compat as apc
import numpy as np
from numpy.typing import NDArray

from ._config import config, xp
from .polynomials import (
ricbesh,
ricbesy,
Expand All @@ -50,6 +52,95 @@
)


def _iter_arrays(*values):
"""Yield array API objects from values, including nested sequences."""

for value in values:
if apc.is_array_api_obj(value):
yield value
elif isinstance(value, (list, tuple)):
yield from _iter_arrays(*value)


def _first_array(*values):
"""Return the first array API object in values, if any."""

return next(_iter_arrays(*values), None)


def _complex_dtype(*values):
"""Return the complex dtype to use for the current xp backend."""

for value in _iter_arrays(*values):
if value.dtype in (xp.float64, xp.complex128):
return xp.complex128

return xp.get_complex_dtype()


def _asarray(value, dtype=None, reference=None):
"""Convert value with xp without detaching existing arrays."""

is_current_backend_array = (
config.get_backend() == "numpy"
and apc.is_numpy_array(value)
or config.get_backend() == "torch"
and apc.is_torch_array(value)
)

if is_current_backend_array:
return xp.astype(value, dtype) if dtype is not None else value

kwargs = {}

if dtype is not None:
kwargs["dtype"] = dtype

if reference is not None:
try:
kwargs["device"] = apc.device(reference)
except TypeError:
pass

try:
return xp.asarray(value, **kwargs)
except TypeError:
kwargs.pop("device", None)
return xp.asarray(value, **kwargs)


def _asarray_vector(value, dtype=None, reference=None):
"""Convert a tensor or sequence of scalars to a one-dimensional array."""

if apc.is_array_api_obj(value):
return xp.reshape(_asarray(value, dtype, reference), (-1,))

return xp.stack(
[
xp.reshape(_asarray(element, dtype, reference), ())
for element in value
]
)


def _zeros(shape, dtype, reference=None):
"""Create a zero array on the same backend as reference."""

kwargs = {"dtype": dtype}

if reference is not None:
try:
kwargs["device"] = apc.device(reference)
except TypeError:
pass

try:
return xp.zeros(shape, **kwargs)
except TypeError:
kwargs.pop("device", None)
return xp.zeros(shape, **kwargs)


#TODO ***??*** revise coefficients - torch, docstring, unit test
def coefficients(
m: float | complex,
Expand Down Expand Up @@ -79,8 +170,19 @@ def coefficients(

"""

A = np.zeros((L,), dtype=np.complex128)
B = np.zeros((L,), dtype=np.complex128)
dtype = _complex_dtype(m, a)
reference = _first_array(m, a)
m = _asarray(m, dtype=dtype, reference=reference)
a = _asarray(a, dtype=dtype, reference=reference)

if L == 0:
return (
_zeros((0,), dtype=dtype, reference=reference),
_zeros((0,), dtype=dtype, reference=reference),
)

A = []
B = []

for l in range(1, L + 1):
Sx = ricbesj(l, a)
Expand All @@ -90,18 +192,18 @@ def coefficients(
xix = ricbesh(l, a)
dxix = dricbesh(l, a)

A[l - 1] = (
(m * Smx * dSx - Sx * dSmx)
/
A.append(
(m * Smx * dSx - Sx * dSmx)
/
(m * Smx * dxix - xix * dSmx)
)
B[l - 1] = (
(Smx * dSx - m * Sx * dSmx)
/
B.append(
(Smx * dSx - m * Sx * dSmx)
/
(Smx * dxix - m * xix * dSmx)
)

return A, B
return xp.stack(A), xp.stack(B)


#TODO ***??*** revise stratified_coefficients - torch, docstring, unit test
Expand Down Expand Up @@ -133,66 +235,86 @@ def stratified_coefficients(
including) order L.

"""
n_layers = len(a)
dtype = _complex_dtype(m, a)
reference = _first_array(m, a)
m = _asarray_vector(m, dtype=dtype, reference=reference)
a = _asarray_vector(a, dtype=dtype, reference=reference)
n_layers = a.shape[0]

if n_layers == 1:
return coefficients(m[0], a[0], L)

an = np.zeros((L,), dtype=np.complex128)
bn = np.zeros((L,), dtype=np.complex128)
if L == 0:
return (
_zeros((0,), dtype=dtype, reference=reference),
_zeros((0,), dtype=dtype, reference=reference),
)

an = []
bn = []

for n in range(L):
A = np.zeros((2 * n_layers, 2 * n_layers), dtype=np.complex128)
C = np.zeros((2 * n_layers, 2 * n_layers), dtype=np.complex128)
A_rows = []
C_rows = []
zero = _zeros((), dtype=dtype, reference=reference)

for i in range(2 * n_layers):
for j in range(2 * n_layers):
p = np.floor((j + 1) / 2).astype(np.int32)
q = np.floor((i / 2)).astype(np.int32)

if not ((p - q == 0) or (p - q == 1)):
continue

if np.mod(i, 2) == 0:
if (j < 2 * n_layers - 1) and ((j == 0) or
(np.mod(j, 2) == 1)):
A[i, j] = dricbesj(n + 1, m[p] * a[q])
elif np.mod(j, 2) == 0:
A[i, j] = dricbesy(n + 1, m[p] * a[q])
else:
A[i, j] = dricbesj(n + 1, a[q])

C[i, j] = (
m[p] * A[i, j]
if j != 2 * n_layers - 1
else A[i, j]
)
else:
if (j < 2 * n_layers - 1) and ((j == 0) or
(np.mod(j, 2) == 1)):
C[i, j] = ricbesj(n + 1, m[p] * a[q])
elif np.mod(j, 2) == 0:
C[i, j] = ricbesy(n + 1, m[p] * a[q])
p = (j + 1) // 2
q = i // 2
A_ij = zero
C_ij = zero

if (p - q == 0) or (p - q == 1):
if i % 2 == 0:
if (
j < 2 * n_layers - 1
and (j == 0 or j % 2 == 1)
):
A_ij = dricbesj(n + 1, m[p] * a[q])
elif j % 2 == 0:
A_ij = dricbesy(n + 1, m[p] * a[q])
else:
A_ij = dricbesj(n + 1, a[q])

if j != 2 * n_layers - 1:
C_ij = m[p] * A_ij
else:
C_ij = A_ij
else:
C[i, j] = ricbesj(n + 1, a[q])

A[i, j] = (
m[p] * C[i, j]
if j != 2 * n_layers - 1
else C[i, j]
)

B = A.copy()
if (
j < 2 * n_layers - 1
and (j == 0 or j % 2 == 1)
):
C_ij = ricbesj(n + 1, m[p] * a[q])
elif j % 2 == 0:
C_ij = ricbesy(n + 1, m[p] * a[q])
else:
C_ij = ricbesj(n + 1, a[q])

if j != 2 * n_layers - 1:
A_ij = m[p] * C_ij
else:
A_ij = C_ij

A_rows.append(A_ij)
C_rows.append(C_ij)

shape = (2 * n_layers, 2 * n_layers)
A = xp.reshape(xp.stack(A_rows), shape)
C = xp.reshape(xp.stack(C_rows), shape)

B = A * 1
B[-2, -1] = dricbesh(n + 1, a[-1])
B[-1, -1] = ricbesh(n + 1, a[-1])
an[n] = np.linalg.det(A) / np.linalg.det(B)
an.append(xp.linalg.det(A) / xp.linalg.det(B))

D = C.copy()
D = C * 1
D[-2, -1] = dricbesh(n + 1, a[-1])
D[-1, -1] = ricbesh(n + 1, a[-1])
bn[n] = np.linalg.det(C) / np.linalg.det(D)
bn.append(xp.linalg.det(C) / xp.linalg.det(D))

return an, bn
return xp.stack(an), xp.stack(bn)


#TODO ***??*** revise harmonics - torch, docstring, unit test
Expand Down Expand Up @@ -231,18 +353,31 @@ def harmonics(

"""

PI = np.zeros((L, *x.shape))
TAU = np.zeros((L, *x.shape))
x = _asarray(x)
reference = _first_array(x)

if L == 0:
return (
_zeros((0, *x.shape), dtype=x.dtype, reference=reference),
_zeros((0, *x.shape), dtype=x.dtype, reference=reference),
)

PI = []
TAU = []

if L >= 1:
PI.append(xp.ones_like(x))
TAU.append(x)

PI[0, :] = 1
PI[1, :] = 3 * x
TAU[0, :] = x
TAU[1, :] = 6 * x * x - 3
if L >= 2:
PI.append(3 * x)
TAU.append(6 * x * x - 3)

for i in range(3, L + 1):
PI[i - 1] = (
(2 * i - 1) / (i - 1) * x * PI[i - 2] - i / (i - 1) * PI[i - 3]
PI.append(
(2 * i - 1) / (i - 1) * x * PI[i - 2]
- i / (i - 1) * PI[i - 3]
)
TAU[i - 1] = i * x * PI[i - 1] - (i + 1) * PI[i - 2]
TAU.append(i * x * PI[i - 1] - (i + 1) * PI[i - 2])

return PI, TAU
return xp.stack(PI), xp.stack(TAU)
Loading