From 11af09c69c95fb0eb510b2010c825ee93c4c4901 Mon Sep 17 00:00:00 2001 From: edudc Date: Tue, 12 May 2026 14:33:47 +0200 Subject: [PATCH 01/13] Add torch Mie backend support --- deeptrack/backend/mie.py | 314 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 310 insertions(+), 4 deletions(-) diff --git a/deeptrack/backend/mie.py b/deeptrack/backend/mie.py index adcfde419..bf4df6b34 100644 --- a/deeptrack/backend/mie.py +++ b/deeptrack/backend/mie.py @@ -49,6 +49,300 @@ dricbesy, ) +try: + import torch + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +def _is_torch_array(x) -> bool: + """Return whether x is, or contains, a torch tensor.""" + + if not TORCH_AVAILABLE: + return False + + if torch.is_tensor(x): + return True + + if isinstance(x, (list, tuple)): + return any(_is_torch_array(v) for v in x) + + return False + + +def _first_torch_array(*values): + """Return the first torch tensor in values, searching nested sequences.""" + + for value in values: + if torch.is_tensor(value): + return value + + if isinstance(value, (list, tuple)): + found = _first_torch_array(*value) + if found is not None: + return found + + return None + + +def _torch_complex_dtype(*values): + """Return a complex dtype compatible with the torch inputs.""" + + for value in values: + if torch.is_tensor(value): + if value.dtype in (torch.float64, torch.complex128): + return torch.complex128 + + if isinstance(value, (list, tuple)): + dtype = _torch_complex_dtype(*value) + if dtype == torch.complex128: + return dtype + + return torch.complex64 + + +def _as_torch_scalar(value, dtype, device): + """Convert value to a scalar torch tensor on device with dtype.""" + + if torch.is_tensor(value): + return value.to(dtype=dtype, device=device) + + return torch.as_tensor(value, dtype=dtype, device=device) + + +def _as_torch_vector(value, dtype, device): + """Convert a tensor or sequence of scalars to a one-dimensional tensor.""" + + if torch.is_tensor(value): + return 
value.to(dtype=dtype, device=device).reshape(-1) + + return torch.stack( + [ + _as_torch_scalar(element, dtype=dtype, device=device).reshape(()) + for element in value + ] + ) + + +def _ricbesj_torch(l: int, x): + """Differentiable torch Riccati-Bessel polynomial of the first kind.""" + + if l == 0: + return torch.sin(x) + + previous = torch.sin(x) + current = torch.sin(x) / x - torch.cos(x) + + for order in range(1, l): + previous, current = current, (2 * order + 1) / x * current - previous + + return current + + +def _dricbesj_torch(l: int, x): + """Differentiable torch derivative of ricbesj.""" + + return _ricbesj_torch(l - 1, x) - l / x * _ricbesj_torch(l, x) + + +def _ricbesy_torch(l: int, x): + """Differentiable torch Riccati-Bessel polynomial of the second kind.""" + + if l == 0: + return torch.cos(x) + + previous = torch.cos(x) + current = torch.cos(x) / x + torch.sin(x) + + for order in range(1, l): + previous, current = current, (2 * order + 1) / x * current - previous + + return current + + +def _dricbesy_torch(l: int, x): + """Differentiable torch derivative of ricbesy.""" + + return _ricbesy_torch(l - 1, x) - l / x * _ricbesy_torch(l, x) + + +def _ricbesh_torch(l: int, x): + """Differentiable torch Riccati-Bessel polynomial of the third kind.""" + + return _ricbesj_torch(l, x) - 1j * _ricbesy_torch(l, x) + + +def _dricbesh_torch(l: int, x): + """Differentiable torch derivative of ricbesh.""" + + return _dricbesj_torch(l, x) - 1j * _dricbesy_torch(l, x) + + +def _coefficients_torch( + m: float | complex | "torch.Tensor", + a: float | "torch.Tensor", + L: int, +) -> tuple["torch.Tensor", "torch.Tensor"]: + """Torch implementation of Mie coefficients.""" + + reference = _first_torch_array(m, a) + device = reference.device + dtype = _torch_complex_dtype(m, a) + + m = _as_torch_scalar(m, dtype=dtype, device=device) + a = _as_torch_scalar(a, dtype=dtype, device=device) + + if L == 0: + empty = torch.empty((0,), dtype=dtype, device=device) + return empty, 
empty.clone() + + A = [] + B = [] + + for l in range(1, L + 1): + Sx = _ricbesj_torch(l, a) + dSx = _dricbesj_torch(l, a) + Smx = _ricbesj_torch(l, m * a) + dSmx = _dricbesj_torch(l, m * a) + xix = _ricbesh_torch(l, a) + dxix = _dricbesh_torch(l, a) + + A.append( + (m * Smx * dSx - Sx * dSmx) + / (m * Smx * dxix - xix * dSmx) + ) + B.append( + (Smx * dSx - m * Sx * dSmx) + / (Smx * dxix - m * xix * dSmx) + ) + + return torch.stack(A), torch.stack(B) + + +def _stratified_coefficients_torch( + m: list[complex] | "torch.Tensor", + a: list[float] | "torch.Tensor", + L: int, +) -> tuple["torch.Tensor", "torch.Tensor"]: + """Torch implementation of stratified Mie coefficients.""" + + reference = _first_torch_array(m, a) + device = reference.device + dtype = _torch_complex_dtype(m, a) + + m = _as_torch_vector(m, dtype=dtype, device=device) + a = _as_torch_vector(a, dtype=dtype, device=device) + n_layers = a.numel() + + if n_layers == 1: + return _coefficients_torch(m[0], a[0], L) + + if L == 0: + empty = torch.empty((0,), dtype=dtype, device=device) + return empty, empty.clone() + + an = [] + bn = [] + + for n in range(L): + A_rows = [] + C_rows = [] + zero = torch.zeros((), dtype=dtype, device=device) + + for i in range(2 * n_layers): + for j in range(2 * n_layers): + p = (j + 1) // 2 + q = i // 2 + + A_ij = zero + C_ij = zero + + if (p - q == 0) or (p - q == 1): + if i % 2 == 0: + if ( + j < 2 * n_layers - 1 + and (j == 0 or j % 2 == 1) + ): + A_ij = _dricbesj_torch(n + 1, m[p] * a[q]) + elif j % 2 == 0: + A_ij = _dricbesy_torch(n + 1, m[p] * a[q]) + else: + A_ij = _dricbesj_torch(n + 1, a[q]) + + if j != 2 * n_layers - 1: + C_ij = m[p] * A_ij + else: + C_ij = A_ij + else: + if ( + j < 2 * n_layers - 1 + and (j == 0 or j % 2 == 1) + ): + C_ij = _ricbesj_torch(n + 1, m[p] * a[q]) + elif j % 2 == 0: + C_ij = _ricbesy_torch(n + 1, m[p] * a[q]) + else: + C_ij = _ricbesj_torch(n + 1, a[q]) + + if j != 2 * n_layers - 1: + A_ij = m[p] * C_ij + else: + A_ij = C_ij + + 
A_rows.append(A_ij) + C_rows.append(C_ij) + + A = torch.stack(A_rows).reshape(2 * n_layers, 2 * n_layers) + C = torch.stack(C_rows).reshape(2 * n_layers, 2 * n_layers) + + B = A.clone() + B[-2, -1] = _dricbesh_torch(n + 1, a[-1]) + B[-1, -1] = _ricbesh_torch(n + 1, a[-1]) + an.append(torch.linalg.det(A) / torch.linalg.det(B)) + + D = C.clone() + D[-2, -1] = _dricbesh_torch(n + 1, a[-1]) + D[-1, -1] = _ricbesh_torch(n + 1, a[-1]) + bn.append(torch.linalg.det(C) / torch.linalg.det(D)) + + return torch.stack(an), torch.stack(bn) + + +def _harmonics_torch( + x: "torch.Tensor", + L: int, +) -> tuple["torch.Tensor", "torch.Tensor"]: + """Torch implementation of Mie harmonics.""" + + PI = [] + TAU = [] + + if L == 0: + shape = (0, *x.shape) + return ( + torch.empty(shape, dtype=x.dtype, device=x.device), + torch.empty(shape, dtype=x.dtype, device=x.device), + ) + + if L >= 1: + PI.append(torch.ones_like(x)) + TAU.append(x) + + if L >= 2: + PI.append(3 * x) + TAU.append(6 * x * x - 3) + + for i in range(3, L + 1): + PI.append( + (2 * i - 1) / (i - 1) * x * PI[i - 2] + - i / (i - 1) * PI[i - 3] + ) + TAU.append(i * x * PI[i - 1] - (i + 1) * PI[i - 2]) + + return torch.stack(PI), torch.stack(TAU) + #TODO ***??*** revise coefficients - torch, docstring, unit test def coefficients( @@ -79,6 +373,9 @@ def coefficients( """ + if _is_torch_array(m) or _is_torch_array(a): + return _coefficients_torch(m, a, L) + A = np.zeros((L,), dtype=np.complex128) B = np.zeros((L,), dtype=np.complex128) @@ -133,6 +430,9 @@ def stratified_coefficients( including) order L. 
""" + if _is_torch_array(m) or _is_torch_array(a): + return _stratified_coefficients_torch(m, a, L) + n_layers = len(a) if n_layers == 1: @@ -231,13 +531,19 @@ def harmonics( """ + if _is_torch_array(x): + return _harmonics_torch(x, L) + PI = np.zeros((L, *x.shape)) TAU = np.zeros((L, *x.shape)) - PI[0, :] = 1 - PI[1, :] = 3 * x - TAU[0, :] = x - TAU[1, :] = 6 * x * x - 3 + if L >= 1: + PI[0, :] = 1 + TAU[0, :] = x + + if L >= 2: + PI[1, :] = 3 * x + TAU[1, :] = 6 * x * x - 3 for i in range(3, L + 1): PI[i - 1] = ( From c6be836ff429359d1e87e0a00e5c762427c662e9 Mon Sep 17 00:00:00 2001 From: edudc Date: Tue, 12 May 2026 15:05:45 +0200 Subject: [PATCH 02/13] Move Mie array API polynomials to backend --- deeptrack/backend/mie.py | 321 ++++++++++++++++--------------- deeptrack/backend/polynomials.py | 128 ++++++++++++ 2 files changed, 292 insertions(+), 157 deletions(-) diff --git a/deeptrack/backend/mie.py b/deeptrack/backend/mie.py index bf4df6b34..2083a847c 100644 --- a/deeptrack/backend/mie.py +++ b/deeptrack/backend/mie.py @@ -37,10 +37,17 @@ from __future__ import annotations +import array_api_compat as apc import numpy as np from numpy.typing import NDArray from .polynomials import ( + _dricbesh_array_api, + _dricbesj_array_api, + _dricbesy_array_api, + _ricbesh_array_api, + _ricbesj_array_api, + _ricbesy_array_api, ricbesh, ricbesy, ricbesj, @@ -49,165 +56,138 @@ dricbesy, ) -try: - import torch - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False +def _first_array(*values): + """Return the first array API object in values.""" + for value in values: + if apc.is_array_api_obj(value): + return value -def _is_torch_array(x) -> bool: - """Return whether x is, or contains, a torch tensor.""" - - if not TORCH_AVAILABLE: - return False + if isinstance(value, (list, tuple)): + found = _first_array(*value) + if found is not None: + return found - if torch.is_tensor(x): - return True + return None - if isinstance(x, (list, tuple)): - return 
any(_is_torch_array(v) for v in x) - return False +def _array_api_namespace(*values): + """Return a non-NumPy array API namespace and reference array if present.""" + reference = _first_array(*values) -def _first_torch_array(*values): - """Return the first torch tensor in values, searching nested sequences.""" + if reference is None: + return None, None - for value in values: - if torch.is_tensor(value): - return value + namespace = apc.array_namespace(reference) - if isinstance(value, (list, tuple)): - found = _first_torch_array(*value) - if found is not None: - return found + if apc.is_numpy_namespace(namespace): + return None, None - return None + return namespace, reference -def _torch_complex_dtype(*values): - """Return a complex dtype compatible with the torch inputs.""" +def _complex_dtype(namespace, *values): + """Return a complex dtype compatible with the array inputs.""" for value in values: - if torch.is_tensor(value): - if value.dtype in (torch.float64, torch.complex128): - return torch.complex128 + if apc.is_array_api_obj(value): + if value.dtype in (namespace.float64, namespace.complex128): + return namespace.complex128 if isinstance(value, (list, tuple)): - dtype = _torch_complex_dtype(*value) - if dtype == torch.complex128: + dtype = _complex_dtype(namespace, *value) + if dtype == namespace.complex128: return dtype - return torch.complex64 + return namespace.complex64 -def _as_torch_scalar(value, dtype, device): - """Convert value to a scalar torch tensor on device with dtype.""" +def _asarray(value, namespace, dtype, reference): + """Convert value to an array on the same backend as reference.""" - if torch.is_tensor(value): - return value.to(dtype=dtype, device=device) + if apc.is_array_api_obj(value): + return namespace.astype(value, dtype) - return torch.as_tensor(value, dtype=dtype, device=device) + try: + return namespace.asarray( + value, dtype=dtype, device=apc.device(reference) + ) + except TypeError: + return namespace.asarray(value, 
dtype=dtype) -def _as_torch_vector(value, dtype, device): - """Convert a tensor or sequence of scalars to a one-dimensional tensor.""" +def _asarray_vector(value, namespace, dtype, reference): + """Convert a tensor or sequence of scalars to a one-dimensional array.""" - if torch.is_tensor(value): - return value.to(dtype=dtype, device=device).reshape(-1) + if apc.is_array_api_obj(value): + return namespace.reshape( + _asarray(value, namespace, dtype, reference), (-1,) + ) - return torch.stack( + return namespace.stack( [ - _as_torch_scalar(element, dtype=dtype, device=device).reshape(()) + namespace.reshape( + _asarray(element, namespace, dtype, reference), () + ) for element in value ] ) -def _ricbesj_torch(l: int, x): - """Differentiable torch Riccati-Bessel polynomial of the first kind.""" - - if l == 0: - return torch.sin(x) - - previous = torch.sin(x) - current = torch.sin(x) / x - torch.cos(x) - - for order in range(1, l): - previous, current = current, (2 * order + 1) / x * current - previous - - return current - - -def _dricbesj_torch(l: int, x): - """Differentiable torch derivative of ricbesj.""" - - return _ricbesj_torch(l - 1, x) - l / x * _ricbesj_torch(l, x) - - -def _ricbesy_torch(l: int, x): - """Differentiable torch Riccati-Bessel polynomial of the second kind.""" +def _empty(namespace, shape, dtype, reference): + """Create an empty array on the same backend as reference.""" - if l == 0: - return torch.cos(x) - - previous = torch.cos(x) - current = torch.cos(x) / x + torch.sin(x) - - for order in range(1, l): - previous, current = current, (2 * order + 1) / x * current - previous - - return current - - -def _dricbesy_torch(l: int, x): - """Differentiable torch derivative of ricbesy.""" - - return _ricbesy_torch(l - 1, x) - l / x * _ricbesy_torch(l, x) - - -def _ricbesh_torch(l: int, x): - """Differentiable torch Riccati-Bessel polynomial of the third kind.""" - - return _ricbesj_torch(l, x) - 1j * _ricbesy_torch(l, x) + try: + return namespace.empty( 
+ shape, dtype=dtype, device=apc.device(reference) + ) + except TypeError: + return namespace.empty(shape, dtype=dtype) -def _dricbesh_torch(l: int, x): - """Differentiable torch derivative of ricbesh.""" +def _zeros(namespace, shape, dtype, reference): + """Create a zero array on the same backend as reference.""" - return _dricbesj_torch(l, x) - 1j * _dricbesy_torch(l, x) + try: + return namespace.zeros( + shape, dtype=dtype, device=apc.device(reference) + ) + except TypeError: + return namespace.zeros(shape, dtype=dtype) -def _coefficients_torch( - m: float | complex | "torch.Tensor", - a: float | "torch.Tensor", +def _coefficients_array_api( + m: float | complex, + a: float, L: int, -) -> tuple["torch.Tensor", "torch.Tensor"]: - """Torch implementation of Mie coefficients.""" + namespace, + reference, +): + """Array API implementation of Mie coefficients.""" - reference = _first_torch_array(m, a) - device = reference.device - dtype = _torch_complex_dtype(m, a) + dtype = _complex_dtype(namespace, m, a) - m = _as_torch_scalar(m, dtype=dtype, device=device) - a = _as_torch_scalar(a, dtype=dtype, device=device) + m = _asarray(m, namespace, dtype, reference) + a = _asarray(a, namespace, dtype, reference) if L == 0: - empty = torch.empty((0,), dtype=dtype, device=device) - return empty, empty.clone() + return ( + _empty(namespace, (0,), dtype, reference), + _empty(namespace, (0,), dtype, reference), + ) A = [] B = [] for l in range(1, L + 1): - Sx = _ricbesj_torch(l, a) - dSx = _dricbesj_torch(l, a) - Smx = _ricbesj_torch(l, m * a) - dSmx = _dricbesj_torch(l, m * a) - xix = _ricbesh_torch(l, a) - dxix = _dricbesh_torch(l, a) + Sx = _ricbesj_array_api(l, a, namespace) + dSx = _dricbesj_array_api(l, a, namespace) + Smx = _ricbesj_array_api(l, m * a, namespace) + dSmx = _dricbesj_array_api(l, m * a, namespace) + xix = _ricbesh_array_api(l, a, namespace) + dxix = _dricbesh_array_api(l, a, namespace) A.append( (m * Smx * dSx - Sx * dSmx) @@ -218,30 +198,34 @@ def 
_coefficients_torch( / (Smx * dxix - m * xix * dSmx) ) - return torch.stack(A), torch.stack(B) + return namespace.stack(A), namespace.stack(B) -def _stratified_coefficients_torch( - m: list[complex] | "torch.Tensor", - a: list[float] | "torch.Tensor", +def _stratified_coefficients_array_api( + m: list[complex], + a: list[float], L: int, -) -> tuple["torch.Tensor", "torch.Tensor"]: - """Torch implementation of stratified Mie coefficients.""" + namespace, + reference, +): + """Array API implementation of stratified Mie coefficients.""" - reference = _first_torch_array(m, a) - device = reference.device - dtype = _torch_complex_dtype(m, a) + dtype = _complex_dtype(namespace, m, a) - m = _as_torch_vector(m, dtype=dtype, device=device) - a = _as_torch_vector(a, dtype=dtype, device=device) - n_layers = a.numel() + m = _asarray_vector(m, namespace, dtype, reference) + a = _asarray_vector(a, namespace, dtype, reference) + n_layers = a.shape[0] if n_layers == 1: - return _coefficients_torch(m[0], a[0], L) + return _coefficients_array_api( + m[0], a[0], L, namespace, reference + ) if L == 0: - empty = torch.empty((0,), dtype=dtype, device=device) - return empty, empty.clone() + return ( + _empty(namespace, (0,), dtype, reference), + _empty(namespace, (0,), dtype, reference), + ) an = [] bn = [] @@ -249,7 +233,7 @@ def _stratified_coefficients_torch( for n in range(L): A_rows = [] C_rows = [] - zero = torch.zeros((), dtype=dtype, device=device) + zero = _zeros(namespace, (), dtype, reference) for i in range(2 * n_layers): for j in range(2 * n_layers): @@ -265,11 +249,17 @@ def _stratified_coefficients_torch( j < 2 * n_layers - 1 and (j == 0 or j % 2 == 1) ): - A_ij = _dricbesj_torch(n + 1, m[p] * a[q]) + A_ij = _dricbesj_array_api( + n + 1, m[p] * a[q], namespace + ) elif j % 2 == 0: - A_ij = _dricbesy_torch(n + 1, m[p] * a[q]) + A_ij = _dricbesy_array_api( + n + 1, m[p] * a[q], namespace + ) else: - A_ij = _dricbesj_torch(n + 1, a[q]) + A_ij = _dricbesj_array_api( + n + 1, 
a[q], namespace + ) if j != 2 * n_layers - 1: C_ij = m[p] * A_ij @@ -280,11 +270,17 @@ def _stratified_coefficients_torch( j < 2 * n_layers - 1 and (j == 0 or j % 2 == 1) ): - C_ij = _ricbesj_torch(n + 1, m[p] * a[q]) + C_ij = _ricbesj_array_api( + n + 1, m[p] * a[q], namespace + ) elif j % 2 == 0: - C_ij = _ricbesy_torch(n + 1, m[p] * a[q]) + C_ij = _ricbesy_array_api( + n + 1, m[p] * a[q], namespace + ) else: - C_ij = _ricbesj_torch(n + 1, a[q]) + C_ij = _ricbesj_array_api( + n + 1, a[q], namespace + ) if j != 2 * n_layers - 1: A_ij = m[p] * C_ij @@ -294,27 +290,30 @@ def _stratified_coefficients_torch( A_rows.append(A_ij) C_rows.append(C_ij) - A = torch.stack(A_rows).reshape(2 * n_layers, 2 * n_layers) - C = torch.stack(C_rows).reshape(2 * n_layers, 2 * n_layers) + shape = (2 * n_layers, 2 * n_layers) + A = namespace.reshape(namespace.stack(A_rows), shape) + C = namespace.reshape(namespace.stack(C_rows), shape) - B = A.clone() - B[-2, -1] = _dricbesh_torch(n + 1, a[-1]) - B[-1, -1] = _ricbesh_torch(n + 1, a[-1]) - an.append(torch.linalg.det(A) / torch.linalg.det(B)) + B = A * 1 + B[-2, -1] = _dricbesh_array_api(n + 1, a[-1], namespace) + B[-1, -1] = _ricbesh_array_api(n + 1, a[-1], namespace) + an.append(namespace.linalg.det(A) / namespace.linalg.det(B)) - D = C.clone() - D[-2, -1] = _dricbesh_torch(n + 1, a[-1]) - D[-1, -1] = _ricbesh_torch(n + 1, a[-1]) - bn.append(torch.linalg.det(C) / torch.linalg.det(D)) + D = C * 1 + D[-2, -1] = _dricbesh_array_api(n + 1, a[-1], namespace) + D[-1, -1] = _ricbesh_array_api(n + 1, a[-1], namespace) + bn.append(namespace.linalg.det(C) / namespace.linalg.det(D)) - return torch.stack(an), torch.stack(bn) + return namespace.stack(an), namespace.stack(bn) -def _harmonics_torch( - x: "torch.Tensor", +def _harmonics_array_api( + x, L: int, -) -> tuple["torch.Tensor", "torch.Tensor"]: - """Torch implementation of Mie harmonics.""" + namespace, + reference, +): + """Array API implementation of Mie harmonics.""" PI = [] TAU = [] @@ 
-322,12 +321,12 @@ def _harmonics_torch( if L == 0: shape = (0, *x.shape) return ( - torch.empty(shape, dtype=x.dtype, device=x.device), - torch.empty(shape, dtype=x.dtype, device=x.device), + _empty(namespace, shape, x.dtype, reference), + _empty(namespace, shape, x.dtype, reference), ) if L >= 1: - PI.append(torch.ones_like(x)) + PI.append(namespace.ones_like(x)) TAU.append(x) if L >= 2: @@ -341,7 +340,7 @@ def _harmonics_torch( ) TAU.append(i * x * PI[i - 1] - (i + 1) * PI[i - 2]) - return torch.stack(PI), torch.stack(TAU) + return namespace.stack(PI), namespace.stack(TAU) #TODO ***??*** revise coefficients - torch, docstring, unit test @@ -373,8 +372,10 @@ def coefficients( """ - if _is_torch_array(m) or _is_torch_array(a): - return _coefficients_torch(m, a, L) + namespace, reference = _array_api_namespace(m, a) + + if namespace is not None: + return _coefficients_array_api(m, a, L, namespace, reference) A = np.zeros((L,), dtype=np.complex128) B = np.zeros((L,), dtype=np.complex128) @@ -430,8 +431,12 @@ def stratified_coefficients( including) order L. 
""" - if _is_torch_array(m) or _is_torch_array(a): - return _stratified_coefficients_torch(m, a, L) + namespace, reference = _array_api_namespace(m, a) + + if namespace is not None: + return _stratified_coefficients_array_api( + m, a, L, namespace, reference + ) n_layers = len(a) @@ -531,8 +536,10 @@ def harmonics( """ - if _is_torch_array(x): - return _harmonics_torch(x, L) + namespace, reference = _array_api_namespace(x) + + if namespace is not None: + return _harmonics_array_api(x, L, namespace, reference) PI = np.zeros((L, *x.shape)) TAU = np.zeros((L, *x.shape)) diff --git a/deeptrack/backend/polynomials.py b/deeptrack/backend/polynomials.py index 64b304c0e..8bf9e9593 100644 --- a/deeptrack/backend/polynomials.py +++ b/deeptrack/backend/polynomials.py @@ -26,11 +26,115 @@ from __future__ import annotations +import array_api_compat as apc import numpy as np from numpy.typing import NDArray from scipy.special import jv, h1vp, yv +def _integer_order(l: int | float) -> int: + """Return l as an integer order supported by recurrence formulas.""" + + order = int(l) + + if order != l or order < 0: + raise ValueError( + "Array API Riccati-Bessel functions require non-negative integer " + "orders." 
+ ) + + return order + + +def _array_namespace(x): + """Return the array namespace for x, or None for Python scalars.""" + + try: + return apc.array_namespace(x) + except TypeError: + return None + + +def _ricbesj_array_api(l: int | float, x, namespace=None): + """Array-API Riccati-Bessel polynomial of the first kind.""" + + l = _integer_order(l) + xp = namespace or apc.array_namespace(x) + + if l == 0: + return xp.sin(x) + + previous = xp.sin(x) + current = xp.sin(x) / x - xp.cos(x) + + for order in range(1, l): + previous, current = current, (2 * order + 1) / x * current - previous + + return current + + +def _dricbesj_array_api(l: int | float, x, namespace=None): + """Array-API derivative of ricbesj.""" + + l = _integer_order(l) + xp = namespace or apc.array_namespace(x) + + if l == 0: + return xp.cos(x) + + return ( + _ricbesj_array_api(l - 1, x, xp) + - l / x * _ricbesj_array_api(l, x, xp) + ) + + +def _ricbesy_array_api(l: int | float, x, namespace=None): + """Array-API Riccati-Bessel polynomial of the second kind.""" + + l = _integer_order(l) + xp = namespace or apc.array_namespace(x) + + if l == 0: + return xp.cos(x) + + previous = xp.cos(x) + current = xp.cos(x) / x + xp.sin(x) + + for order in range(1, l): + previous, current = current, (2 * order + 1) / x * current - previous + + return current + + +def _dricbesy_array_api(l: int | float, x, namespace=None): + """Array-API derivative of ricbesy.""" + + l = _integer_order(l) + xp = namespace or apc.array_namespace(x) + + if l == 0: + return -xp.sin(x) + + return ( + _ricbesy_array_api(l - 1, x, xp) + - l / x * _ricbesy_array_api(l, x, xp) + ) + + +def _ricbesh_array_api(l: int | float, x, namespace=None): + """Array-API Riccati-Bessel polynomial of the third kind.""" + + xp = namespace or apc.array_namespace(x) + return _ricbesj_array_api(l, x, xp) - 1j * _ricbesy_array_api(l, x, xp) + + +def _dricbesh_array_api(l: int | float, x, namespace=None): + """Array-API derivative of ricbesh.""" + + xp = 
namespace or apc.array_namespace(x) + return _dricbesj_array_api(l, x, xp) - 1j * _dricbesy_array_api(l, x, xp) + + #TODO ***??*** revise besselj - torch, docstring, unit test def besselj( l: int | float, @@ -148,6 +252,10 @@ def ricbesj( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _ricbesj_array_api(l, x, namespace) + return np.sqrt(np.pi * x / 2) * besselj(l + 0.5, x) @@ -172,6 +280,10 @@ def dricbesj( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _dricbesj_array_api(l, x, namespace) + return 0.5 * np.sqrt(np.pi / x / 2) * besselj(l + 0.5, x) + np.sqrt( np.pi * x / 2 ) * dbesselj(l + 0.5, x) @@ -198,6 +310,10 @@ def ricbesy( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _ricbesy_array_api(l, x, namespace) + return -np.sqrt(np.pi * x / 2) * bessely(l + 0.5, x) @@ -222,6 +338,10 @@ def dricbesy( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _dricbesy_array_api(l, x, namespace) + return -0.5 * np.sqrt(np.pi / 2 / x) * yv(l + 0.5, x) - np.sqrt( np.pi * x / 2 ) * dbessely(l + 0.5, x) @@ -248,6 +368,10 @@ def ricbesh( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _ricbesh_array_api(l, x, namespace) + return np.sqrt(np.pi * x / 2) * h1vp(l + 0.5, x, False) @@ -272,6 +396,10 @@ def dricbesh( """ + namespace = _array_namespace(x) + if namespace is not None and not apc.is_numpy_namespace(namespace): + return _dricbesh_array_api(l, x, namespace) + xi = 0.5 * np.sqrt(np.pi / 2 / x) * h1vp(l + 0.5, x, False) + np.sqrt( np.pi * x / 2 ) * h1vp(l + 0.5, x, True) From 3039d2514e0c552a83ed75be5b4e13064a4bbbfb Mon Sep 17 00:00:00 2001 From: edudc Date: Tue, 12 May 2026 15:06:22 +0200 Subject: [PATCH 03/13] Test torch Mie backend 
autodiff --- tests/backend/test_mie.py | 110 +++++++++++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 1 deletion(-) diff --git a/tests/backend/test_mie.py b/tests/backend/test_mie.py index 1c654dd35..7e9bf009c 100644 --- a/tests/backend/test_mie.py +++ b/tests/backend/test_mie.py @@ -10,7 +10,10 @@ import numpy as np -from deeptrack.backend import mie +from deeptrack.backend import mie, TORCH_AVAILABLE + +if TORCH_AVAILABLE: + import torch class TestMie(unittest.TestCase): @@ -113,5 +116,110 @@ def test_harmonics(self): self.assertTrue(np.allclose(TAU, TAU_expected)) +@unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.") +class TestMieTorch(unittest.TestCase): + + def test_coefficients_matches_numpy_and_autodiff(self): + m = 1.5 + 0.01j + a_np = 0.5 + L = 5 + + A_expected, B_expected = mie.coefficients(m, a_np, L) + + a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) + A, B = mie.coefficients(m, a, L) + + self.assertIsInstance(A, torch.Tensor) + self.assertIsInstance(B, torch.Tensor) + self.assertEqual(A.shape, (L,)) + self.assertEqual(B.shape, (L,)) + + self.assertTrue( + np.allclose(A.detach().numpy(), A_expected, rtol=1e-10, atol=1e-10) + ) + self.assertTrue( + np.allclose(B.detach().numpy(), B_expected, rtol=1e-10, atol=1e-10) + ) + + loss = torch.abs(A).sum() + torch.abs(B).sum() + loss.backward() + + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(a.grad)) + self.assertGreater(abs(float(a.grad)), 0) + + def test_coefficients_refractive_index_autodiff(self): + m = torch.tensor(1.5, dtype=torch.float64, requires_grad=True) + a = torch.tensor(0.5, dtype=torch.float64, requires_grad=True) + + A, B = mie.coefficients(m, a, 5) + + loss = torch.abs(A).sum() + torch.abs(B).sum() + loss.backward() + + self.assertIsNotNone(m.grad) + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(m.grad)) + self.assertTrue(torch.isfinite(a.grad)) + self.assertGreater(abs(float(m.grad)), 0) + 
self.assertGreater(abs(float(a.grad)), 0) + + def test_stratified_coefficients_matches_numpy_and_autodiff(self): + m = [1.5 + 0.01j, 1.2 + 0.02j] + a_np = [0.5, 0.3] + L = 5 + + an_expected, bn_expected = mie.stratified_coefficients(m, a_np, L) + + a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) + an, bn = mie.stratified_coefficients(m, a, L) + + self.assertIsInstance(an, torch.Tensor) + self.assertIsInstance(bn, torch.Tensor) + self.assertEqual(an.shape, (L,)) + self.assertEqual(bn.shape, (L,)) + + self.assertTrue( + np.allclose( + an.detach().numpy(), an_expected, rtol=1e-10, atol=1e-10 + ) + ) + self.assertTrue( + np.allclose( + bn.detach().numpy(), bn_expected, rtol=1e-10, atol=1e-10 + ) + ) + + loss = torch.abs(an).sum() + torch.abs(bn).sum() + loss.backward() + + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(a.grad).all()) + self.assertGreater(float(torch.linalg.vector_norm(a.grad)), 0) + + def test_harmonics_matches_numpy_and_autodiff(self): + x_np = np.array([0.4]) + L = 4 + PI_expected, TAU_expected = mie.harmonics(x_np, L) + + x = torch.tensor(x_np, dtype=torch.float64, requires_grad=True) + PI, TAU = mie.harmonics(x, L) + + self.assertIsInstance(PI, torch.Tensor) + self.assertIsInstance(TAU, torch.Tensor) + self.assertEqual(PI.shape, (L, 1)) + self.assertEqual(TAU.shape, (L, 1)) + + self.assertTrue(np.allclose(PI.detach().numpy(), PI_expected)) + self.assertTrue(np.allclose(TAU.detach().numpy(), TAU_expected)) + + loss = PI.sum() + TAU.sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all()) + self.assertGreater(float(torch.linalg.vector_norm(x.grad)), 0) + + if __name__ == "__main__": unittest.main() From 1ab6a98c14a47a32771ac3b0f773c755500a25a0 Mon Sep 17 00:00:00 2001 From: edudc Date: Tue, 12 May 2026 16:07:51 +0200 Subject: [PATCH 04/13] Simplify Mie backend xp usage --- deeptrack/backend/mie.py | 476 ++++++++++++-------------------------- tests/backend/test_mie.py | 136 
++++++----- 2 files changed, 224 insertions(+), 388 deletions(-) diff --git a/deeptrack/backend/mie.py b/deeptrack/backend/mie.py index 2083a847c..368704dad 100644 --- a/deeptrack/backend/mie.py +++ b/deeptrack/backend/mie.py @@ -41,13 +41,8 @@ import numpy as np from numpy.typing import NDArray +from ._config import config, xp from .polynomials import ( - _dricbesh_array_api, - _dricbesj_array_api, - _dricbesy_array_api, - _ricbesh_array_api, - _ricbesj_array_api, - _ricbesy_array_api, ricbesh, ricbesy, ricbesj, @@ -57,290 +52,93 @@ ) -def _first_array(*values): - """Return the first array API object in values.""" +def _iter_arrays(*values): + """Yield array API objects from values, including nested sequences.""" for value in values: if apc.is_array_api_obj(value): - return value - - if isinstance(value, (list, tuple)): - found = _first_array(*value) - if found is not None: - return found - - return None - + yield value + elif isinstance(value, (list, tuple)): + yield from _iter_arrays(*value) -def _array_api_namespace(*values): - """Return a non-NumPy array API namespace and reference array if present.""" - reference = _first_array(*values) +def _first_array(*values): + """Return the first array API object in values, if any.""" - if reference is None: - return None, None + return next(_iter_arrays(*values), None) - namespace = apc.array_namespace(reference) - if apc.is_numpy_namespace(namespace): - return None, None +def _complex_dtype(*values): + """Return the complex dtype to use for the current xp backend.""" - return namespace, reference + for value in _iter_arrays(*values): + if value.dtype in (xp.float64, xp.complex128): + return xp.complex128 + return xp.get_complex_dtype() -def _complex_dtype(namespace, *values): - """Return a complex dtype compatible with the array inputs.""" - for value in values: - if apc.is_array_api_obj(value): - if value.dtype in (namespace.float64, namespace.complex128): - return namespace.complex128 +def _asarray(value, 
dtype=None, reference=None): + """Convert value with xp without detaching existing arrays.""" - if isinstance(value, (list, tuple)): - dtype = _complex_dtype(namespace, *value) - if dtype == namespace.complex128: - return dtype + is_current_backend_array = ( + config.get_backend() == "numpy" + and apc.is_numpy_array(value) + or config.get_backend() == "torch" + and apc.is_torch_array(value) + ) - return namespace.complex64 + if is_current_backend_array: + return xp.astype(value, dtype) if dtype is not None else value + kwargs = {} -def _asarray(value, namespace, dtype, reference): - """Convert value to an array on the same backend as reference.""" + if dtype is not None: + kwargs["dtype"] = dtype - if apc.is_array_api_obj(value): - return namespace.astype(value, dtype) + if reference is not None: + try: + kwargs["device"] = apc.device(reference) + except TypeError: + pass try: - return namespace.asarray( - value, dtype=dtype, device=apc.device(reference) - ) + return xp.asarray(value, **kwargs) except TypeError: - return namespace.asarray(value, dtype=dtype) + kwargs.pop("device", None) + return xp.asarray(value, **kwargs) -def _asarray_vector(value, namespace, dtype, reference): +def _asarray_vector(value, dtype=None, reference=None): """Convert a tensor or sequence of scalars to a one-dimensional array.""" if apc.is_array_api_obj(value): - return namespace.reshape( - _asarray(value, namespace, dtype, reference), (-1,) - ) + return xp.reshape(_asarray(value, dtype, reference), (-1,)) - return namespace.stack( + return xp.stack( [ - namespace.reshape( - _asarray(element, namespace, dtype, reference), () - ) + xp.reshape(_asarray(element, dtype, reference), ()) for element in value ] ) -def _empty(namespace, shape, dtype, reference): - """Create an empty array on the same backend as reference.""" - - try: - return namespace.empty( - shape, dtype=dtype, device=apc.device(reference) - ) - except TypeError: - return namespace.empty(shape, dtype=dtype) +def 
_zeros(shape, dtype, reference=None): + """Create a zero array on the same backend as reference.""" + kwargs = {"dtype": dtype} -def _zeros(namespace, shape, dtype, reference): - """Create a zero array on the same backend as reference.""" + if reference is not None: + try: + kwargs["device"] = apc.device(reference) + except TypeError: + pass try: - return namespace.zeros( - shape, dtype=dtype, device=apc.device(reference) - ) + return xp.zeros(shape, **kwargs) except TypeError: - return namespace.zeros(shape, dtype=dtype) - - -def _coefficients_array_api( - m: float | complex, - a: float, - L: int, - namespace, - reference, -): - """Array API implementation of Mie coefficients.""" - - dtype = _complex_dtype(namespace, m, a) - - m = _asarray(m, namespace, dtype, reference) - a = _asarray(a, namespace, dtype, reference) - - if L == 0: - return ( - _empty(namespace, (0,), dtype, reference), - _empty(namespace, (0,), dtype, reference), - ) - - A = [] - B = [] - - for l in range(1, L + 1): - Sx = _ricbesj_array_api(l, a, namespace) - dSx = _dricbesj_array_api(l, a, namespace) - Smx = _ricbesj_array_api(l, m * a, namespace) - dSmx = _dricbesj_array_api(l, m * a, namespace) - xix = _ricbesh_array_api(l, a, namespace) - dxix = _dricbesh_array_api(l, a, namespace) - - A.append( - (m * Smx * dSx - Sx * dSmx) - / (m * Smx * dxix - xix * dSmx) - ) - B.append( - (Smx * dSx - m * Sx * dSmx) - / (Smx * dxix - m * xix * dSmx) - ) - - return namespace.stack(A), namespace.stack(B) - - -def _stratified_coefficients_array_api( - m: list[complex], - a: list[float], - L: int, - namespace, - reference, -): - """Array API implementation of stratified Mie coefficients.""" - - dtype = _complex_dtype(namespace, m, a) - - m = _asarray_vector(m, namespace, dtype, reference) - a = _asarray_vector(a, namespace, dtype, reference) - n_layers = a.shape[0] - - if n_layers == 1: - return _coefficients_array_api( - m[0], a[0], L, namespace, reference - ) - - if L == 0: - return ( - _empty(namespace, 
(0,), dtype, reference), - _empty(namespace, (0,), dtype, reference), - ) - - an = [] - bn = [] - - for n in range(L): - A_rows = [] - C_rows = [] - zero = _zeros(namespace, (), dtype, reference) - - for i in range(2 * n_layers): - for j in range(2 * n_layers): - p = (j + 1) // 2 - q = i // 2 - - A_ij = zero - C_ij = zero - - if (p - q == 0) or (p - q == 1): - if i % 2 == 0: - if ( - j < 2 * n_layers - 1 - and (j == 0 or j % 2 == 1) - ): - A_ij = _dricbesj_array_api( - n + 1, m[p] * a[q], namespace - ) - elif j % 2 == 0: - A_ij = _dricbesy_array_api( - n + 1, m[p] * a[q], namespace - ) - else: - A_ij = _dricbesj_array_api( - n + 1, a[q], namespace - ) - - if j != 2 * n_layers - 1: - C_ij = m[p] * A_ij - else: - C_ij = A_ij - else: - if ( - j < 2 * n_layers - 1 - and (j == 0 or j % 2 == 1) - ): - C_ij = _ricbesj_array_api( - n + 1, m[p] * a[q], namespace - ) - elif j % 2 == 0: - C_ij = _ricbesy_array_api( - n + 1, m[p] * a[q], namespace - ) - else: - C_ij = _ricbesj_array_api( - n + 1, a[q], namespace - ) - - if j != 2 * n_layers - 1: - A_ij = m[p] * C_ij - else: - A_ij = C_ij - - A_rows.append(A_ij) - C_rows.append(C_ij) - - shape = (2 * n_layers, 2 * n_layers) - A = namespace.reshape(namespace.stack(A_rows), shape) - C = namespace.reshape(namespace.stack(C_rows), shape) - - B = A * 1 - B[-2, -1] = _dricbesh_array_api(n + 1, a[-1], namespace) - B[-1, -1] = _ricbesh_array_api(n + 1, a[-1], namespace) - an.append(namespace.linalg.det(A) / namespace.linalg.det(B)) - - D = C * 1 - D[-2, -1] = _dricbesh_array_api(n + 1, a[-1], namespace) - D[-1, -1] = _ricbesh_array_api(n + 1, a[-1], namespace) - bn.append(namespace.linalg.det(C) / namespace.linalg.det(D)) - - return namespace.stack(an), namespace.stack(bn) - - -def _harmonics_array_api( - x, - L: int, - namespace, - reference, -): - """Array API implementation of Mie harmonics.""" - - PI = [] - TAU = [] - - if L == 0: - shape = (0, *x.shape) - return ( - _empty(namespace, shape, x.dtype, reference), - _empty(namespace, 
shape, x.dtype, reference), - ) - - if L >= 1: - PI.append(namespace.ones_like(x)) - TAU.append(x) - - if L >= 2: - PI.append(3 * x) - TAU.append(6 * x * x - 3) - - for i in range(3, L + 1): - PI.append( - (2 * i - 1) / (i - 1) * x * PI[i - 2] - - i / (i - 1) * PI[i - 3] - ) - TAU.append(i * x * PI[i - 1] - (i + 1) * PI[i - 2]) - - return namespace.stack(PI), namespace.stack(TAU) + kwargs.pop("device", None) + return xp.zeros(shape, **kwargs) #TODO ***??*** revise coefficients - torch, docstring, unit test @@ -372,13 +170,19 @@ def coefficients( """ - namespace, reference = _array_api_namespace(m, a) + dtype = _complex_dtype(m, a) + reference = _first_array(m, a) + m = _asarray(m, dtype=dtype, reference=reference) + a = _asarray(a, dtype=dtype, reference=reference) - if namespace is not None: - return _coefficients_array_api(m, a, L, namespace, reference) + if L == 0: + return ( + _zeros((0,), dtype=dtype, reference=reference), + _zeros((0,), dtype=dtype, reference=reference), + ) - A = np.zeros((L,), dtype=np.complex128) - B = np.zeros((L,), dtype=np.complex128) + A = [] + B = [] for l in range(1, L + 1): Sx = ricbesj(l, a) @@ -388,18 +192,18 @@ def coefficients( xix = ricbesh(l, a) dxix = dricbesh(l, a) - A[l - 1] = ( - (m * Smx * dSx - Sx * dSmx) - / + A.append( + (m * Smx * dSx - Sx * dSmx) + / (m * Smx * dxix - xix * dSmx) ) - B[l - 1] = ( - (Smx * dSx - m * Sx * dSmx) - / + B.append( + (Smx * dSx - m * Sx * dSmx) + / (Smx * dxix - m * xix * dSmx) ) - return A, B + return xp.stack(A), xp.stack(B) #TODO ***??*** revise stratified_coefficients - torch, docstring, unit test @@ -431,73 +235,86 @@ def stratified_coefficients( including) order L. 
""" - namespace, reference = _array_api_namespace(m, a) - - if namespace is not None: - return _stratified_coefficients_array_api( - m, a, L, namespace, reference - ) - - n_layers = len(a) + dtype = _complex_dtype(m, a) + reference = _first_array(m, a) + m = _asarray_vector(m, dtype=dtype, reference=reference) + a = _asarray_vector(a, dtype=dtype, reference=reference) + n_layers = a.shape[0] if n_layers == 1: return coefficients(m[0], a[0], L) - an = np.zeros((L,), dtype=np.complex128) - bn = np.zeros((L,), dtype=np.complex128) + if L == 0: + return ( + _zeros((0,), dtype=dtype, reference=reference), + _zeros((0,), dtype=dtype, reference=reference), + ) + + an = [] + bn = [] for n in range(L): - A = np.zeros((2 * n_layers, 2 * n_layers), dtype=np.complex128) - C = np.zeros((2 * n_layers, 2 * n_layers), dtype=np.complex128) + A_rows = [] + C_rows = [] + zero = _zeros((), dtype=dtype, reference=reference) for i in range(2 * n_layers): for j in range(2 * n_layers): - p = np.floor((j + 1) / 2).astype(np.int32) - q = np.floor((i / 2)).astype(np.int32) - - if not ((p - q == 0) or (p - q == 1)): - continue - - if np.mod(i, 2) == 0: - if (j < 2 * n_layers - 1) and ((j == 0) or - (np.mod(j, 2) == 1)): - A[i, j] = dricbesj(n + 1, m[p] * a[q]) - elif np.mod(j, 2) == 0: - A[i, j] = dricbesy(n + 1, m[p] * a[q]) - else: - A[i, j] = dricbesj(n + 1, a[q]) - - C[i, j] = ( - m[p] * A[i, j] - if j != 2 * n_layers - 1 - else A[i, j] - ) - else: - if (j < 2 * n_layers - 1) and ((j == 0) or - (np.mod(j, 2) == 1)): - C[i, j] = ricbesj(n + 1, m[p] * a[q]) - elif np.mod(j, 2) == 0: - C[i, j] = ricbesy(n + 1, m[p] * a[q]) + p = (j + 1) // 2 + q = i // 2 + A_ij = zero + C_ij = zero + + if (p - q == 0) or (p - q == 1): + if i % 2 == 0: + if ( + j < 2 * n_layers - 1 + and (j == 0 or j % 2 == 1) + ): + A_ij = dricbesj(n + 1, m[p] * a[q]) + elif j % 2 == 0: + A_ij = dricbesy(n + 1, m[p] * a[q]) + else: + A_ij = dricbesj(n + 1, a[q]) + + if j != 2 * n_layers - 1: + C_ij = m[p] * A_ij + else: + 
C_ij = A_ij else: - C[i, j] = ricbesj(n + 1, a[q]) + if ( + j < 2 * n_layers - 1 + and (j == 0 or j % 2 == 1) + ): + C_ij = ricbesj(n + 1, m[p] * a[q]) + elif j % 2 == 0: + C_ij = ricbesy(n + 1, m[p] * a[q]) + else: + C_ij = ricbesj(n + 1, a[q]) - A[i, j] = ( - m[p] * C[i, j] - if j != 2 * n_layers - 1 - else C[i, j] - ) + if j != 2 * n_layers - 1: + A_ij = m[p] * C_ij + else: + A_ij = C_ij - B = A.copy() + A_rows.append(A_ij) + C_rows.append(C_ij) + + shape = (2 * n_layers, 2 * n_layers) + A = xp.reshape(xp.stack(A_rows), shape) + C = xp.reshape(xp.stack(C_rows), shape) + + B = A * 1 B[-2, -1] = dricbesh(n + 1, a[-1]) B[-1, -1] = ricbesh(n + 1, a[-1]) - an[n] = np.linalg.det(A) / np.linalg.det(B) + an.append(xp.linalg.det(A) / xp.linalg.det(B)) - D = C.copy() + D = C * 1 D[-2, -1] = dricbesh(n + 1, a[-1]) D[-1, -1] = ricbesh(n + 1, a[-1]) - bn[n] = np.linalg.det(C) / np.linalg.det(D) + bn.append(xp.linalg.det(C) / xp.linalg.det(D)) - return an, bn + return xp.stack(an), xp.stack(bn) #TODO ***??*** revise harmonics - torch, docstring, unit test @@ -536,26 +353,31 @@ def harmonics( """ - namespace, reference = _array_api_namespace(x) + x = _asarray(x) + reference = _first_array(x) - if namespace is not None: - return _harmonics_array_api(x, L, namespace, reference) + if L == 0: + return ( + _zeros((0, *x.shape), dtype=x.dtype, reference=reference), + _zeros((0, *x.shape), dtype=x.dtype, reference=reference), + ) - PI = np.zeros((L, *x.shape)) - TAU = np.zeros((L, *x.shape)) + PI = [] + TAU = [] if L >= 1: - PI[0, :] = 1 - TAU[0, :] = x + PI.append(xp.ones_like(x)) + TAU.append(x) if L >= 2: - PI[1, :] = 3 * x - TAU[1, :] = 6 * x * x - 3 + PI.append(3 * x) + TAU.append(6 * x * x - 3) for i in range(3, L + 1): - PI[i - 1] = ( - (2 * i - 1) / (i - 1) * x * PI[i - 2] - i / (i - 1) * PI[i - 3] + PI.append( + (2 * i - 1) / (i - 1) * x * PI[i - 2] + - i / (i - 1) * PI[i - 3] ) - TAU[i - 1] = i * x * PI[i - 1] - (i + 1) * PI[i - 2] + TAU.append(i * x * PI[i - 1] - (i + 1) * 
PI[i - 2]) - return PI, TAU + return xp.stack(PI), xp.stack(TAU) diff --git a/tests/backend/test_mie.py b/tests/backend/test_mie.py index 7e9bf009c..de8b3da36 100644 --- a/tests/backend/test_mie.py +++ b/tests/backend/test_mie.py @@ -10,7 +10,7 @@ import numpy as np -from deeptrack.backend import mie, TORCH_AVAILABLE +from deeptrack.backend import config, mie, TORCH_AVAILABLE if TORCH_AVAILABLE: import torch @@ -18,6 +18,9 @@ class TestMie(unittest.TestCase): + def setUp(self): + config.set_backend("numpy") + def test_coefficients(self): m = 1.5 + 0.01j a = 0.5 @@ -119,6 +122,9 @@ def test_harmonics(self): @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.") class TestMieTorch(unittest.TestCase): + def setUp(self): + config.set_backend("numpy") + def test_coefficients_matches_numpy_and_autodiff(self): m = 1.5 + 0.01j a_np = 0.5 @@ -126,43 +132,49 @@ def test_coefficients_matches_numpy_and_autodiff(self): A_expected, B_expected = mie.coefficients(m, a_np, L) - a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) - A, B = mie.coefficients(m, a, L) + with config.with_backend("torch"): + a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) + A, B = mie.coefficients(m, a, L) - self.assertIsInstance(A, torch.Tensor) - self.assertIsInstance(B, torch.Tensor) - self.assertEqual(A.shape, (L,)) - self.assertEqual(B.shape, (L,)) + self.assertIsInstance(A, torch.Tensor) + self.assertIsInstance(B, torch.Tensor) + self.assertEqual(A.shape, (L,)) + self.assertEqual(B.shape, (L,)) - self.assertTrue( - np.allclose(A.detach().numpy(), A_expected, rtol=1e-10, atol=1e-10) - ) - self.assertTrue( - np.allclose(B.detach().numpy(), B_expected, rtol=1e-10, atol=1e-10) - ) + self.assertTrue( + np.allclose( + A.detach().numpy(), A_expected, rtol=1e-10, atol=1e-10 + ) + ) + self.assertTrue( + np.allclose( + B.detach().numpy(), B_expected, rtol=1e-10, atol=1e-10 + ) + ) - loss = torch.abs(A).sum() + torch.abs(B).sum() - loss.backward() + loss = 
torch.abs(A).sum() + torch.abs(B).sum() + loss.backward() - self.assertIsNotNone(a.grad) - self.assertTrue(torch.isfinite(a.grad)) - self.assertGreater(abs(float(a.grad)), 0) + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(a.grad)) + self.assertGreater(abs(float(a.grad)), 0) def test_coefficients_refractive_index_autodiff(self): - m = torch.tensor(1.5, dtype=torch.float64, requires_grad=True) - a = torch.tensor(0.5, dtype=torch.float64, requires_grad=True) + with config.with_backend("torch"): + m = torch.tensor(1.5, dtype=torch.float64, requires_grad=True) + a = torch.tensor(0.5, dtype=torch.float64, requires_grad=True) - A, B = mie.coefficients(m, a, 5) + A, B = mie.coefficients(m, a, 5) - loss = torch.abs(A).sum() + torch.abs(B).sum() - loss.backward() + loss = torch.abs(A).sum() + torch.abs(B).sum() + loss.backward() - self.assertIsNotNone(m.grad) - self.assertIsNotNone(a.grad) - self.assertTrue(torch.isfinite(m.grad)) - self.assertTrue(torch.isfinite(a.grad)) - self.assertGreater(abs(float(m.grad)), 0) - self.assertGreater(abs(float(a.grad)), 0) + self.assertIsNotNone(m.grad) + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(m.grad)) + self.assertTrue(torch.isfinite(a.grad)) + self.assertGreater(abs(float(m.grad)), 0) + self.assertGreater(abs(float(a.grad)), 0) def test_stratified_coefficients_matches_numpy_and_autodiff(self): m = [1.5 + 0.01j, 1.2 + 0.02j] @@ -171,54 +183,56 @@ def test_stratified_coefficients_matches_numpy_and_autodiff(self): an_expected, bn_expected = mie.stratified_coefficients(m, a_np, L) - a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) - an, bn = mie.stratified_coefficients(m, a, L) + with config.with_backend("torch"): + a = torch.tensor(a_np, dtype=torch.float64, requires_grad=True) + an, bn = mie.stratified_coefficients(m, a, L) - self.assertIsInstance(an, torch.Tensor) - self.assertIsInstance(bn, torch.Tensor) - self.assertEqual(an.shape, (L,)) - self.assertEqual(bn.shape, (L,)) + 
self.assertIsInstance(an, torch.Tensor) + self.assertIsInstance(bn, torch.Tensor) + self.assertEqual(an.shape, (L,)) + self.assertEqual(bn.shape, (L,)) - self.assertTrue( - np.allclose( - an.detach().numpy(), an_expected, rtol=1e-10, atol=1e-10 + self.assertTrue( + np.allclose( + an.detach().numpy(), an_expected, rtol=1e-10, atol=1e-10 + ) ) - ) - self.assertTrue( - np.allclose( - bn.detach().numpy(), bn_expected, rtol=1e-10, atol=1e-10 + self.assertTrue( + np.allclose( + bn.detach().numpy(), bn_expected, rtol=1e-10, atol=1e-10 + ) ) - ) - loss = torch.abs(an).sum() + torch.abs(bn).sum() - loss.backward() + loss = torch.abs(an).sum() + torch.abs(bn).sum() + loss.backward() - self.assertIsNotNone(a.grad) - self.assertTrue(torch.isfinite(a.grad).all()) - self.assertGreater(float(torch.linalg.vector_norm(a.grad)), 0) + self.assertIsNotNone(a.grad) + self.assertTrue(torch.isfinite(a.grad).all()) + self.assertGreater(float(torch.linalg.vector_norm(a.grad)), 0) def test_harmonics_matches_numpy_and_autodiff(self): x_np = np.array([0.4]) L = 4 PI_expected, TAU_expected = mie.harmonics(x_np, L) - x = torch.tensor(x_np, dtype=torch.float64, requires_grad=True) - PI, TAU = mie.harmonics(x, L) + with config.with_backend("torch"): + x = torch.tensor(x_np, dtype=torch.float64, requires_grad=True) + PI, TAU = mie.harmonics(x, L) - self.assertIsInstance(PI, torch.Tensor) - self.assertIsInstance(TAU, torch.Tensor) - self.assertEqual(PI.shape, (L, 1)) - self.assertEqual(TAU.shape, (L, 1)) + self.assertIsInstance(PI, torch.Tensor) + self.assertIsInstance(TAU, torch.Tensor) + self.assertEqual(PI.shape, (L, 1)) + self.assertEqual(TAU.shape, (L, 1)) - self.assertTrue(np.allclose(PI.detach().numpy(), PI_expected)) - self.assertTrue(np.allclose(TAU.detach().numpy(), TAU_expected)) + self.assertTrue(np.allclose(PI.detach().numpy(), PI_expected)) + self.assertTrue(np.allclose(TAU.detach().numpy(), TAU_expected)) - loss = PI.sum() + TAU.sum() - loss.backward() + loss = PI.sum() + TAU.sum() + 
loss.backward() - self.assertIsNotNone(x.grad) - self.assertTrue(torch.isfinite(x.grad).all()) - self.assertGreater(float(torch.linalg.vector_norm(x.grad)), 0) + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all()) + self.assertGreater(float(torch.linalg.vector_norm(x.grad)), 0) if __name__ == "__main__": From 288c8ea469f546bcb5227c26c393e00b2efc6935 Mon Sep 17 00:00:00 2001 From: edudc Date: Wed, 13 May 2026 11:37:17 +0200 Subject: [PATCH 05/13] converted Mie scatterer setup, grids, masks, polarization math, FFTs, propagation use, and stratified coefficients to preserve torch tensors/autograd --- deeptrack/optical/scatterers.py | 288 ++++++++++++++++++++++---------- 1 file changed, 201 insertions(+), 87 deletions(-) diff --git a/deeptrack/optical/scatterers.py b/deeptrack/optical/scatterers.py index 613c48f7a..5a2f67f2d 100644 --- a/deeptrack/optical/scatterers.py +++ b/deeptrack/optical/scatterers.py @@ -187,7 +187,7 @@ get_active_scale, get_active_voxel_size, ) -from deeptrack.backend import mie, TORCH_AVAILABLE, xp +from deeptrack.backend import config, mie, TORCH_AVAILABLE, xp from deeptrack.optical.math import AveragePooling, pad_image_to_fft from deeptrack.features import ( Feature, @@ -201,6 +201,37 @@ import torch +def _asarray(value, dtype=None): + """Convert values through xp while preserving existing tensor gradients.""" + + is_current_backend_array = ( + config.get_backend() == "numpy" + and apc.is_numpy_array(value) + or config.get_backend() == "torch" + and apc.is_torch_array(value) + ) + + if is_current_backend_array: + return xp.astype(value, dtype) if dtype is not None else value + + if dtype is not None: + return xp.asarray(value, dtype=dtype) + return xp.asarray(value) + + +def _asarray_vector(value, dtype=None): + """Convert a vector-like value without detaching tensor elements.""" + + if isinstance(value, (list, tuple)) and any( + apc.is_array_api_obj(element) for element in value + ): + return xp.stack( + 
[xp.reshape(_asarray(element, dtype), ()) for element in value] + ) + + return xp.reshape(_asarray(value, dtype), (-1,)) + + __all__ = [ "Scatterer", "PointParticle", @@ -1418,20 +1449,27 @@ def _process_properties( if properties["L"] == "auto": try: + radius_for_l = properties["radius"] + if TORCH_AVAILABLE and torch.is_tensor(radius_for_l): + radius_for_l = radius_for_l.detach().cpu().numpy() + v = ( 2 * np.pi - * np.max(properties["radius"]) + * np.max(radius_for_l) / properties["wavelength"] ) properties["L"] = int(np.floor((v + 4 * (v ** (1 / 3)) + 1))) - except (ValueError, TypeError): + except (ValueError, TypeError, RuntimeError): pass if properties["collection_angle"] == "auto": - properties["collection_angle"] = np.arcsin( + collection_arg = ( properties["NA"] / properties["refractive_index_medium"] ) + if config.get_backend() == "torch": + collection_arg = _asarray(collection_arg, dtype=xp.float64) + properties["collection_angle"] = xp.asin(collection_arg) if properties["offset_z"] == "auto": size = ( @@ -1443,11 +1481,17 @@ def _process_properties( # offset_z should be calculated with the physical size of the image # not the fft-padded size min_edge_size = np.min([xSize, ySize]) + collection_angle = properties["collection_angle"] + if config.get_backend() == "torch": + collection_angle = _asarray( + collection_angle, + dtype=xp.float64, + ) properties["offset_z"] = ( min_edge_size * 0.45 * min(get_active_voxel_size()[:2]) - / np.tan(properties["collection_angle"]) + / xp.tan(collection_angle) ) return properties @@ -1497,9 +1541,11 @@ def get_xy_grid( """ - x = np.arange(shape[0]) - shape[0] / 2 - y = np.arange(shape[1]) - shape[1] / 2 - return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij") + x = xp.arange(shape[0], dtype=xp.float64) + y = xp.arange(shape[1], dtype=xp.float64) + x = x - shape[0] / 2 + y = y - shape[1] / 2 + return xp.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij") def get_detector_mask( self: 
MieScatterer, @@ -1527,7 +1573,7 @@ def get_detector_mask( """ - return np.sqrt(X**2 + Y**2) < radius + return xp.sqrt(X**2 + Y**2) < radius def _plane_in_polar_coords_geometric( self: MieScatterer, @@ -1577,13 +1623,13 @@ def _plane_in_polar_coords_geometric( Z = plane_position[2] R2_squared = X**2 + Y**2 - R3 = np.sqrt(R2_squared + Z**2) + R3 = xp.sqrt(R2_squared + Z**2) cos_theta = Z / R3 - illumination_cos_theta = np.cos( - np.arccos(cos_theta) + illumination_angle + illumination_cos_theta = xp.cos( + xp.acos(cos_theta) + illumination_angle ) - phi = np.arctan2(Y, X) + phi = xp.atan2(Y, X) return R3, cos_theta, illumination_cos_theta, phi @@ -1633,18 +1679,24 @@ def _plane_in_polar_coords_hybrid( Z = plane_position[2] R2_squared = X**2 + Y**2 - R3 = np.sqrt(R2_squared + Z**2) + R3 = xp.sqrt(R2_squared + Z**2) - Q = np.sqrt(R2_squared) / voxel_size[0] ** 2 * 2 * np.pi / shape[0] + Q = xp.sqrt(R2_squared) / voxel_size[0] ** 2 * 2 * np.pi / shape[0] sin_theta = Q / (k) pupil_mask = sin_theta < 1 - cos_theta = np.zeros(sin_theta.shape) - cos_theta[pupil_mask] = np.sqrt(1 - sin_theta[pupil_mask] ** 2) + cos_theta = xp.sqrt( + xp.maximum(xp.zeros_like(sin_theta), 1 - sin_theta**2) + ) + cos_theta = xp.where( + pupil_mask, + cos_theta, + xp.zeros_like(cos_theta), + ) - illumination_cos_theta = np.cos( - np.arccos(cos_theta) + illumination_angle + illumination_cos_theta = xp.cos( + xp.acos(cos_theta) + illumination_angle ) - phi = np.arctan2(Y, X) + phi = xp.atan2(Y, X) return R3, cos_theta, illumination_cos_theta, phi, pupil_mask @@ -1683,33 +1735,37 @@ def _polarization_coefficients( """ - if isinstance(input_polarization, (float, int, str, Quantity)): - if isinstance(input_polarization, Quantity): - input_polarization = input_polarization.to("rad").magnitude - - if isinstance(input_polarization, (float, int)): - S1_coef = np.sin(phi + input_polarization) - S2_coef = np.cos(phi + input_polarization) + if isinstance(input_polarization, Quantity): + input_polarization 
= input_polarization.to("rad").magnitude - elif ( - isinstance(input_polarization, str) - and input_polarization == "circular" - ): - S1_coef = 1 / np.sqrt(2) - S2_coef = 1j / np.sqrt(2) - else: + if isinstance(input_polarization, str): + if input_polarization != "circular": raise TypeError( f"Unsupported input_polarization: {input_polarization}" ) + S1_coef = 1 / np.sqrt(2) + S2_coef = 1j / np.sqrt(2) + else: + input_polarization = _asarray( + input_polarization, + dtype=xp.float64, + ) + S1_coef = xp.sin(phi + input_polarization) + S2_coef = xp.cos(phi + input_polarization) - if isinstance(output_polarization, (float, int, Quantity)): - if isinstance(output_polarization, Quantity): - output_polarization = output_polarization.to("rad").magnitude + if isinstance(output_polarization, Quantity): + output_polarization = output_polarization.to("rad").magnitude - S1_coef *= np.sin(phi + output_polarization) - S2_coef *= ( - np.cos(phi + output_polarization) * illumination_cos_theta - ) + output_polarization = _asarray( + output_polarization, + dtype=xp.float64, + ) + S1_coef = S1_coef * xp.sin(phi + output_polarization) + S2_coef = ( + S2_coef + * xp.cos(phi + output_polarization) + * illumination_cos_theta + ) return S1_coef, S2_coef @@ -1802,27 +1858,57 @@ def _common_setup( """ xSize, ySize = self.get_xy_size(output_region, padding) - voxel_size = get_active_voxel_size() - scale = get_active_scale() + voxel_size = xp.asarray( + get_active_voxel_size(), + dtype=xp.float64, + ) + scale = xp.asarray( + get_active_scale(), + dtype=xp.float64, + ) - arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex) + arr = pad_image_to_fft( + xp.zeros((xSize, ySize), dtype=xp.complex128) + ) + position = _asarray_vector( + position, + dtype=xp.float64, + ) position = ( - np.array(position) + position * scale[: len(position)] * voxel_size[: len(position)] ) + wavelength = _asarray(wavelength, dtype=xp.float64) + refractive_index_medium = _asarray( + refractive_index_medium, 
+ dtype=xp.float64, + ) + collection_angle = _asarray( + collection_angle, + dtype=xp.float64, + ) + working_distance = _asarray( + working_distance, + dtype=xp.float64, + ) + z = _asarray(z, dtype=xp.float64) z = z * voxel_size[2] * scale[2] + position_objective = _asarray_vector( + position_objective, + dtype=xp.float64, + ) - pupil_physical_size = working_distance * np.tan(collection_angle) * 2 + pupil_physical_size = working_distance * xp.tan(collection_angle) * 2 k = 2 * np.pi / wavelength * refractive_index_medium - relative_position = np.array( - ( + relative_position = xp.stack( + [ position_objective[0] - position[0], position_objective[1] - position[1], working_distance - z, - ) + ] ) return ( @@ -2008,20 +2094,20 @@ def _solve_geometric( ) ) - cos_phi_field = np.cos(phi_field) - sin_phi_field = np.sin(phi_field) + cos_phi_field = xp.cos(phi_field) + sin_phi_field = xp.sin(phi_field) x_farfield = ( position[0] + R3_field - * np.sqrt(1 - cos_theta_field**2) + * xp.sqrt(1 - cos_theta_field**2) * cos_phi_field / ratio ) y_farfield = ( position[1] + R3_field - * np.sqrt(1 - cos_theta_field**2) + * xp.sqrt(1 - cos_theta_field**2) * sin_phi_field / ratio ) @@ -2048,29 +2134,29 @@ def _solve_geometric( arr[pupil_mask] = ( -1j / (k * R3_field) - * np.exp(1j * k * R3_field) + * xp.exp(1j * k * R3_field) * (S2 * S2_coef + S1 * S1_coef) ) / amp_factor # For phase shift correction (a multiplication of the field # by exp(1j * k * z)). if phase_shift_correction: - arr *= np.exp(1j * k * z + 1j * np.pi / 2) + arr = arr * xp.exp(1j * k * z + 1j * np.pi / 2) # For partially coherent illumination. 
if coherence_length: - sigma = z * np.sqrt((coherence_length / z + 1) ** 2 - 1) + sigma = z * xp.sqrt((coherence_length / z + 1) ** 2 - 1) sigma = sigma * (offset_z / z) - mask = np.zeros_like(arr) - y, x = np.ogrid[ - -mask.shape[0] // 2 : mask.shape[0] // 2, - -mask.shape[1] // 2 : mask.shape[1] // 2, - ] - mask = np.exp(-0.5 * (x**2 + y**2) / ((sigma) ** 2)) + y = xp.arange(arr.shape[0], dtype=xp.float64) + x = xp.arange(arr.shape[1], dtype=xp.float64) + y = y - arr.shape[0] // 2 + x = x - arr.shape[1] // 2 + y, x = xp.meshgrid(y, x, indexing="ij") + mask = xp.exp(-0.5 * (x**2 + y**2) / ((sigma) ** 2)) arr = arr * mask - fourier_field = np.fft.fft2(arr) + fourier_field = xp.fft.fft2(arr) propagation_matrix = get_propagation_matrix( fourier_field.shape, @@ -2089,11 +2175,15 @@ def _solve_geometric( ), ) - fourier_field *= propagation_matrix * np.exp(-1j * k * offset_z) + fourier_field = ( + fourier_field + * propagation_matrix + * xp.exp(-1j * k * offset_z) + ) if return_fft: - return fourier_field[..., np.newaxis] - return np.fft.ifft2(fourier_field)[..., np.newaxis] + return fourier_field[..., None] + return xp.fft.ifft2(fourier_field)[..., None] def _solve_hybrid( self: MieScatterer, @@ -2230,20 +2320,20 @@ def _solve_hybrid( k, ) - cos_phi_field = np.cos(phi_field) - sin_phi_field = np.sin(phi_field) + cos_phi_field = xp.cos(phi_field) + sin_phi_field = xp.sin(phi_field) x_farfield = ( position[0] + R3_field - * np.sqrt(1 - cos_theta_field**2) + * xp.sqrt(1 - cos_theta_field**2) * cos_phi_field / ratio ) y_farfield = ( position[1] + R3_field - * np.sqrt(1 - cos_theta_field**2) + * xp.sqrt(1 - cos_theta_field**2) * sin_phi_field / ratio ) @@ -2261,19 +2351,19 @@ def _solve_hybrid( # For phase shift correction (a multiplication of the field # by exp(1j * k * z)). if phase_shift_correction: - arr *= np.exp(1j * k * z + 1j * np.pi / 2) + arr = arr * xp.exp(1j * k * z + 1j * np.pi / 2) # For partially coherent illumination. 
if coherence_length: - sigma = z * np.sqrt((coherence_length / z + 1) ** 2 - 1) + sigma = z * xp.sqrt((coherence_length / z + 1) ** 2 - 1) sigma = sigma * (offset_z / z) - mask = np.zeros_like(arr) - y, x = np.ogrid[ - -mask.shape[0] // 2 : mask.shape[0] // 2, - -mask.shape[1] // 2 : mask.shape[1] // 2, - ] - mask = np.exp(-0.5 * (x**2 + y**2) / ((sigma) ** 2)) + y = xp.arange(arr.shape[0], dtype=xp.float64) + x = xp.arange(arr.shape[1], dtype=xp.float64) + y = y - arr.shape[0] // 2 + x = x - arr.shape[1] // 2 + y, x = xp.meshgrid(y, x, indexing="ij") + mask = xp.exp(-0.5 * (x**2 + y**2) / ((sigma) ** 2)) arr = arr * mask if pupil is not None and len(pupil) > 0: @@ -2281,10 +2371,15 @@ def _solve_hybrid( c1 = arr.shape[1] // 2 h0 = pupil.shape[0] // 2 h1 = pupil.shape[1] // 2 - arr[c0 - h0 : c0 + h0, c1 - h1 : c1 + h1] *= pupil + pupil_mask = xp.ones_like(arr) + pupil_mask[c0 - h0 : c0 + h0, c1 - h1 : c1 + h1] = _asarray( + pupil, + dtype=arr.dtype, + ) + arr = arr * pupil_mask - fourier_field = np.fft.ifft2( - np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))) + fourier_field = xp.fft.ifft2( + xp.fft.fftshift(xp.fft.fft2(xp.fft.fftshift(arr))) ) propagation_matrix = get_propagation_matrix( @@ -2304,11 +2399,11 @@ def _solve_hybrid( ), ) - fourier_field *= propagation_matrix + fourier_field = fourier_field * propagation_matrix if return_fft: - return fourier_field[..., np.newaxis] - return np.fft.ifft2(fourier_field)[..., np.newaxis] + return fourier_field[..., None] + return xp.fft.ifft2(fourier_field)[..., None] class MieSphere(MieScatterer): @@ -2537,7 +2632,21 @@ def coeffs( """ - if not np.all(radius[1:] >= radius[:-1]): + radius_for_check = radius + if TORCH_AVAILABLE and torch.is_tensor(radius_for_check): + radius_for_check = radius_for_check.detach().cpu().numpy() + elif isinstance(radius_for_check, (list, tuple)): + radius_for_check = [ + item.detach().cpu().numpy() + if TORCH_AVAILABLE and torch.is_tensor(item) + else item + for item in radius_for_check + 
] + + if not np.all( + np.asarray(radius_for_check)[1:] + >= np.asarray(radius_for_check)[:-1] + ): raise ValueError( "Radius of the shells of a stratified sphere should be " "monotonically increasing." @@ -2545,8 +2654,13 @@ def coeffs( def inner(L: int): return mie.stratified_coefficients( - np.array(refractive_index) / refractive_index_medium, - np.array(radius) + _asarray_vector( + refractive_index, + ) + / refractive_index_medium, + _asarray_vector( + radius, + ) * 2 * np.pi / wavelength From 364edf89731f32458b13e1b75c605a5e7db20960 Mon Sep 17 00:00:00 2001 From: edudc Date: Wed, 13 May 2026 11:38:02 +0200 Subject: [PATCH 06/13] made get_propagation_matrix backend-aware --- deeptrack/optical/holography.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/deeptrack/optical/holography.py b/deeptrack/optical/holography.py index 4f0432bef..73672766a 100644 --- a/deeptrack/optical/holography.py +++ b/deeptrack/optical/holography.py @@ -93,6 +93,7 @@ def get_propagation_matrix( import numpy as np +from deeptrack.backend import xp from deeptrack.backend.units import get_active_voxel_size from deeptrack import Feature @@ -144,7 +145,9 @@ def get_propagation_matrix( if pixel_size is None: pixel_size = get_active_voxel_size() - if np.isscalar(pixel_size): + if np.isscalar(pixel_size) or ( + hasattr(pixel_size, "ndim") and pixel_size.ndim == 0 + ): pixel_size = (pixel_size, pixel_size) px, py = pixel_size @@ -152,21 +155,26 @@ def get_propagation_matrix( k = 2 * np.pi / wavelength yr, xr, *_ = shape - x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2 - y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2 + x = xp.arange(0, xr, 1, dtype=xp.float64) - xr / 2 + (xr % 2) / 2 + y = xp.arange(0, yr, 1, dtype=xp.float64) - yr / 2 + (yr % 2) / 2 x = 2 * np.pi / px * x / xr y = 2 * np.pi / py * y / yr - KXk, KYk = np.meshgrid(x, y) - KXk = KXk.astype(complex) - KYk = KYk.astype(complex) + KXk_real, KYk_real = xp.meshgrid(x, y) + KXk = 
xp.astype(KXk_real, xp.complex128) + KYk = xp.astype(KYk_real, xp.complex128) - K = np.real(np.sqrt(1 - (KXk / k) ** 2 - (KYk / k) ** 2)) - C = np.fft.fftshift(((KXk / k) ** 2 + (KYk / k) ** 2 < 1) * 1.0) + K = xp.real(xp.sqrt(1 - (KXk / k) ** 2 - (KYk / k) ** 2)) + C = xp.fft.fftshift( + xp.astype( + ((KXk_real / k) ** 2 + (KYk_real / k) ** 2 < 1), + xp.float64, + ) + ) - return C * np.fft.fftshift( - np.exp(k * 1j * (to_z * (K - 1) - dx * KXk / k - dy * KYk / k)) + return C * xp.fft.fftshift( + xp.exp(k * 1j * (to_z * (K - 1) - dx * KXk / k - dy * KYk / k)) ) From 39dc0547bc53041613c6c06dc6d0cd0dcf4aa9a8 Mon Sep 17 00:00:00 2001 From: edudc Date: Wed, 13 May 2026 11:39:47 +0200 Subject: [PATCH 07/13] torch regression tests for Mie sphere autodiff, multi-field brightfield summation, propagation matrix autodiff, and Zernike coefficient gradients --- tests/test_aberrations.py | 17 +++++- tests/test_holography.py | 52 +++++++++++++--- tests/test_scatterers.py | 125 +++++++++++++++++++++++++++++++++++++- 3 files changed, 182 insertions(+), 12 deletions(-) diff --git a/tests/test_aberrations.py b/tests/test_aberrations.py index c0bd29901..a50960881 100644 --- a/tests/test_aberrations.py +++ b/tests/test_aberrations.py @@ -216,6 +216,21 @@ def testSphericalAberration_resolves(self): class TestAberrations_PyTorch(TestAberrations_NumPy): BACKEND = "torch" + def test_zero_zernike_coefficient_keeps_torch_gradient(self): + coefficient = torch.tensor( + 0.0, + dtype=torch.float64, + requires_grad=True, + ) + pupil = aberrations.Zernike(n=2, m=0, coefficient=coefficient) + image = self._make_optics(pupil)(self.particle).resolve() + + loss = image.sum() + loss.backward() + + self.assertIsNotNone(coefficient.grad) + self.assertTrue(torch.isfinite(coefficient.grad)) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_holography.py b/tests/test_holography.py index f08531c66..43b4f95f1 100644 --- 
a/tests/test_holography.py +++ b/tests/test_holography.py @@ -6,22 +6,54 @@ import numpy as np +from deeptrack.backend import TORCH_AVAILABLE, config from deeptrack.optical import holography +if TORCH_AVAILABLE: + import torch + + class TestOpticalFieldFunctions(unittest.TestCase): - - def test_get_propagation_matrix(self): - propagation_matrix = holography.get_propagation_matrix( - shape=(128, 128), - to_z=1.0, - pixel_size=0.1, - wavelength=0.65e-6, - dx=0, - dy=0 - ) + + def test_get_propagation_matrix(self): + with config.with_backend("numpy"): + propagation_matrix = holography.get_propagation_matrix( + shape=(128, 128), + to_z=1.0, + pixel_size=0.1, + wavelength=0.65e-6, + dx=0, + dy=0, + ) self.assertEqual(propagation_matrix.shape, (128, 128)) self.assertTrue(np.iscomplexobj(propagation_matrix)) + @unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.") + def test_get_propagation_matrix_torch_autodiff(self): + with config.with_backend("torch"): + wavelength = torch.tensor( + 0.65e-6, + dtype=torch.float64, + requires_grad=True, + ) + propagation_matrix = holography.get_propagation_matrix( + shape=(32, 32), + to_z=1.0, + pixel_size=0.1, + wavelength=wavelength, + dx=0, + dy=0, + ) + + self.assertIsInstance(propagation_matrix, torch.Tensor) + self.assertEqual(propagation_matrix.shape, (32, 32)) + self.assertTrue(torch.is_complex(propagation_matrix)) + + loss = torch.real(propagation_matrix).sum() + loss.backward() + self.assertIsNotNone(wavelength.grad) + self.assertTrue(torch.isfinite(wavelength.grad)) + def test_rescale(self): rescale_factor = 0.5 image = np.random.rand(128, 128, 2) diff --git a/tests/test_scatterers.py b/tests/test_scatterers.py index 2b260162d..08ef9d7a4 100644 --- a/tests/test_scatterers.py +++ b/tests/test_scatterers.py @@ -7,7 +7,7 @@ import numpy as np from deeptrack.backend import TORCH_AVAILABLE -from deeptrack.optical.optics import Fluorescence +from deeptrack.optical.optics import Brightfield, Fluorescence from 
deeptrack.optical import scatterers from tests import BackendTestBase @@ -593,6 +593,129 @@ class TestScatterers_Torch(TestScatterers_NumPy): class TestMath_TorchOnly(BackendTestBase): BACKEND = "torch" + def _torch_mie_sphere(self, mode, radius, refractive_index, **kwargs): + params = dict( + radius=radius, + refractive_index=refractive_index, + position=(16, 16), + position_unit="pixel", + wavelength=680e-9, + refractive_index_medium=1.33, + NA=0.7, + output_region=(0, 0, 32, 32), + padding=(0, 0, 0, 0), + input_polarization=0.0, + output_polarization=0.0, + return_fft=False, + L=5, + collection_angle=0.3, + offset_z=1e-5, + mode=mode, + ) + params.update(kwargs) + return scatterers.MieSphere(**params) + + def test_mie_sphere_resolves_with_torch_autodiff(self): + for mode in ("geometric", "hybrid"): + with self.subTest(mode=mode): + radius = torch.tensor( + 0.5e-6, + dtype=torch.float64, + requires_grad=True, + ) + refractive_index = torch.tensor( + 1.45, + dtype=torch.float64, + requires_grad=True, + ) + + out = self._torch_mie_sphere( + mode, + radius, + refractive_index, + ).resolve() + + self.assertIsInstance(out.array, torch.Tensor) + self.assertEqual(out.shape, (32, 32, 1)) + self.assertTrue(torch.is_complex(out.array)) + self.assertTrue(torch.isfinite(out.array.real).all()) + self.assertTrue(torch.isfinite(out.array.imag).all()) + self.assertGreater( + float(torch.abs(out.array).sum().detach()), + 0, + ) + self.assertTrue(out.array.requires_grad) + + loss = torch.abs(out.array).sum() + loss.backward() + + self.assertIsNotNone(radius.grad) + self.assertIsNotNone(refractive_index.grad) + self.assertTrue(torch.isfinite(radius.grad)) + self.assertTrue(torch.isfinite(refractive_index.grad)) + self.assertGreater(abs(float(radius.grad)), 0) + self.assertGreater(abs(float(refractive_index.grad)), 0) + + def test_mie_sphere_brightfield_sums_multiple_torch_fields(self): + radius_1 = torch.tensor( + 0.45e-6, + dtype=torch.float64, + requires_grad=True, + ) + radius_2 
= torch.tensor( + 0.55e-6, + dtype=torch.float64, + requires_grad=True, + ) + + common = dict( + refractive_index=1.45, + input_polarization=0.0, + output_polarization=0.0, + L=5, + collection_angle=0.3, + offset_z=1e-5, + mode="hybrid", + ) + sample = scatterers.MieSphere( + radius=radius_1, + position=(14, 16), + position_unit="pixel", + **common, + ) >> scatterers.MieSphere( + radius=radius_2, + position=(18, 16), + position_unit="pixel", + **common, + ) + microscope = Brightfield( + NA=0.7, + wavelength=680e-9, + resolution=1e-6, + magnification=10, + output_region=(0, 0, 32, 32), + padding=(4, 4, 4, 4), + return_field=True, + ) + + image = microscope(sample).resolve() + + self.assertIsInstance(image, torch.Tensor) + self.assertEqual(image.shape, (32, 32, 1)) + self.assertTrue(torch.is_complex(image)) + self.assertTrue(torch.isfinite(image.real).all()) + self.assertTrue(torch.isfinite(image.imag).all()) + + loss = torch.abs(image).sum() + loss.backward() + + self.assertIsNotNone(radius_1.grad) + self.assertIsNotNone(radius_2.grad) + self.assertTrue(torch.isfinite(radius_1.grad)) + self.assertTrue(torch.isfinite(radius_2.grad)) + self.assertGreater(abs(float(radius_1.grad)), 0) + self.assertGreater(abs(float(radius_2.grad)), 0) + def test_point_particle_intensity_gradient(self): # --- PointParticle intensity optimization --- From fca7b0417937f72722bd6959c2d7cdca650015d1 Mon Sep 17 00:00:00 2001 From: edudc Date: Wed, 13 May 2026 14:25:29 +0200 Subject: [PATCH 08/13] the derivative through torch.arccos(cos_theta) blew up for zero illumination angle; this now uses the equivalent cos_theta branch --- deeptrack/optical/scatterers.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/deeptrack/optical/scatterers.py b/deeptrack/optical/scatterers.py index 5a2f67f2d..4c23e8a85 100644 --- a/deeptrack/optical/scatterers.py +++ b/deeptrack/optical/scatterers.py @@ -1626,9 +1626,12 @@ def _plane_in_polar_coords_geometric( R3 =
xp.sqrt(R2_squared + Z**2) cos_theta = Z / R3 - illumination_cos_theta = xp.cos( - xp.acos(cos_theta) + illumination_angle - ) + if float(illumination_angle) == 0: + illumination_cos_theta = cos_theta + else: + illumination_cos_theta = xp.cos( + xp.acos(cos_theta) + illumination_angle + ) phi = xp.atan2(Y, X) return R3, cos_theta, illumination_cos_theta, phi @@ -1693,9 +1696,12 @@ def _plane_in_polar_coords_hybrid( xp.zeros_like(cos_theta), ) - illumination_cos_theta = xp.cos( - xp.acos(cos_theta) + illumination_angle - ) + if float(illumination_angle) == 0: + illumination_cos_theta = cos_theta + else: + illumination_cos_theta = xp.cos( + xp.acos(cos_theta) + illumination_angle + ) phi = xp.atan2(Y, X) return R3, cos_theta, illumination_cos_theta, phi, pupil_mask From 9a808714ea8a9662447b4752535f423753ac57c5 Mon Sep 17 00:00:00 2001 From: edudc Date: Mon, 25 May 2026 15:29:59 +0200 Subject: [PATCH 09/13] Pint triggers some numpy conversion. Added a path for tensors so that they don't go through Pint. 
--- deeptrack/backend/units.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/deeptrack/backend/units.py b/deeptrack/backend/units.py index 5d40397da..96d3f6e44 100644 --- a/deeptrack/backend/units.py +++ b/deeptrack/backend/units.py @@ -386,6 +386,22 @@ def convert( default_unit, desired_unit = value + if TORCH_AVAILABLE and torch.is_tensor(quantity): + factor = (1 * default_unit).to(desired_unit).to_reduced_units() + factor = factor.magnitude + kwargs[key] = quantity * factor + continue + + if ( + TORCH_AVAILABLE + and isinstance(quantity, (list, tuple)) + and any(torch.is_tensor(item) for item in quantity) + ): + factor = (1 * default_unit).to(desired_unit).to_reduced_units() + factor = factor.magnitude + kwargs[key] = type(quantity)(item * factor for item in quantity) + continue + # Convert non-quantities to quantities in default units if not isinstance(quantity, Quantity): quantity = quantity * default_unit From 9b6f833cc1e9056ca22945776817849f4b60d164 Mon Sep 17 00:00:00 2001 From: edudc Date: Mon, 25 May 2026 16:21:49 +0200 Subject: [PATCH 10/13] modifications for learnable x, y, resolution, na, magnification, wavelength --- deeptrack/optical/optics.py | 189 ++++++++++++++++++++++---------- deeptrack/optical/scatterers.py | 35 ++++-- tests/test_scatterers.py | 85 ++++++++++++++ 3 files changed, 239 insertions(+), 70 deletions(-) diff --git a/deeptrack/optical/optics.py b/deeptrack/optical/optics.py index 7e9cb7eb3..6aac98007 100644 --- a/deeptrack/optical/optics.py +++ b/deeptrack/optical/optics.py @@ -312,9 +312,23 @@ def get( if np.array(_upscale_given_by_optics).size == 1: _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 + voxel_size_for_context = additional_sample_kwargs["voxel_size"] + if TORCH_AVAILABLE and torch.is_tensor(voxel_size_for_context): + voxel_size_for_context = ( + voxel_size_for_context.detach().cpu().numpy() + ) + elif TORCH_AVAILABLE and isinstance( + voxel_size_for_context, + (list, tuple), + ): + 
voxel_size_for_context = type(voxel_size_for_context)( + item.detach().cpu().item() if torch.is_tensor(item) else item + for item in voxel_size_for_context + ) + with u.context( create_context( - *additional_sample_kwargs["voxel_size"], + *voxel_size_for_context, *_upscale_given_by_optics, ) ): @@ -629,7 +643,11 @@ def get_voxel_size( props = self._normalize( resolution=resolution, magnification=magnification ) - return np.ones((3,)) * props["resolution"] / props["magnification"] + return ( + xp.ones((3,), dtype=xp.float64) + * props["resolution"] + / props["magnification"] + ) def get_pixel_size( resolution: ( @@ -711,8 +729,20 @@ def _process_properties( NA = propertydict["NA"] wavelength = propertydict["wavelength"] - voxel_size = get_active_voxel_size() - radius = NA / wavelength * np.array(voxel_size) + voxel_size = propertydict.get("voxel_size", get_active_voxel_size()) + if TORCH_AVAILABLE and torch.is_tensor(NA): + NA = NA.detach().cpu().numpy() + if TORCH_AVAILABLE and torch.is_tensor(wavelength): + wavelength = wavelength.detach().cpu().numpy() + if TORCH_AVAILABLE and torch.is_tensor(voxel_size): + voxel_size = voxel_size.detach().cpu().numpy() + elif TORCH_AVAILABLE and isinstance(voxel_size, (list, tuple)): + voxel_size = [ + item.detach().cpu().item() if torch.is_tensor(item) else item + for item in voxel_size + ] + + radius = NA / wavelength * np.array(voxel_size, dtype=float) if np.any(radius[:2] > 0.5): required_upscale = np.max(np.ceil(radius[:2] * 2)) @@ -890,17 +920,34 @@ def _pupil_torch( semantics. 
""" - # Resolve device - if isinstance(defocus, torch.Tensor): - device = defocus.device - complex_dtype = ( - defocus.dtype - if defocus.dtype in (torch.complex64, torch.complex128) - else torch.complex64 + voxel_size = kwargs.get("voxel_size", get_active_voxel_size()) + + tensor_refs = [ + value + for value in (defocus, NA, wavelength, refractive_index_medium) + if torch.is_tensor(value) + ] + if torch.is_tensor(voxel_size): + tensor_refs.append(voxel_size) + elif isinstance(voxel_size, (list, tuple)): + tensor_refs.extend( + value for value in voxel_size if torch.is_tensor(value) ) - else: - device = torch.device("cpu") - complex_dtype = torch.complex64 + + device = kwargs.get("device") or ( + tensor_refs[0].device if tensor_refs else torch.device("cpu") + ) + real_dtype = ( + torch.float64 + if any( + value.dtype in (torch.float64, torch.complex128) + for value in tensor_refs + ) + else torch.float32 + ) + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) # shape -> (H, W) following current usage where shape[0] is x-axis length shape_arr = np.array(shape, dtype=int) @@ -910,18 +957,37 @@ def _pupil_torch( H = int(shape_arr[0]) W = int(shape_arr[1]) - voxel_size_np = np.array( - get_active_voxel_size(), dtype=float - ) # (vx, vy, vz) - # Use python floats for constants; this is fine for differentiability - # w.r.t. volume - # If you ever want gradients w.r.t voxel_size, you’d pass it as - # torch.Tensor. 
- vx, vy, vz = ( - float(voxel_size_np[0]), - float(voxel_size_np[1]), - float(voxel_size_np[2]), + if isinstance(voxel_size, (list, tuple)): + voxel_size = torch.stack( + [ + torch.as_tensor( + value, + device=device, + dtype=real_dtype, + ) + for value in voxel_size + ] + ) + else: + voxel_size = torch.as_tensor( + voxel_size, + device=device, + dtype=real_dtype, + ) + + NA = torch.as_tensor(NA, device=device, dtype=real_dtype) + wavelength = torch.as_tensor( + wavelength, + device=device, + dtype=real_dtype, ) + refractive_index_medium = torch.as_tensor( + refractive_index_medium, + device=device, + dtype=real_dtype, + ) + + vx, vy, vz = voxel_size[0], voxel_size[1], voxel_size[2] # Pupil radius Rx = (NA / wavelength) * vx @@ -929,15 +995,6 @@ def _pupil_torch( x_radius = Rx * H y_radius = Ry * W - # Build coordinates exactly like NumPy: - # np.linspace(-(N/2), N/2 - 1, N) / radius + 1e-8 - # Use float for coordinate grid to reduce artifacts - real_dtype = ( - torch.float32 - if complex_dtype == torch.complex64 - else torch.float64 - ) - x = ( torch.linspace( -H / 2.0, @@ -946,7 +1003,7 @@ def _pupil_torch( device=device, dtype=real_dtype, ) - / float(x_radius) + / x_radius + 1e-8 ) @@ -958,7 +1015,7 @@ def _pupil_torch( device=device, dtype=real_dtype, ) - / float(y_radius) + / y_radius + 1e-8 ) @@ -970,33 +1027,41 @@ def _pupil_torch( pupil_function = (RHO.real < 1.0).to(complex_dtype) - k0 = 2.0 * np.pi * float(refractive_index_medium) / float(wavelength) - alpha = (float(NA) / float(refractive_index_medium)) ** 2 + k0 = 2.0 * np.pi * refractive_index_medium / wavelength + alpha = (NA / refractive_index_medium) ** 2 - inside = 1.0 - alpha * RHO # complex - sqrt_term = torch.sqrt(inside.to(complex_dtype)) + # inside = 1.0 - alpha * RHO # complex + # sqrt_term = torch.sqrt(inside.to(complex_dtype)) - z_shift = (k0 * float(vz)) * sqrt_term # complex + # z_shift = (k0 * float(vz)) * sqrt_term # complex - # Torch equivalent: - z_shift = torch.where( - 
z_shift.imag.abs() > 1e-12, - torch.zeros_like(z_shift), - z_shift, - ) + # # Torch equivalent: + # z_shift = torch.where( + # z_shift.imag.abs() > 1e-12, + # torch.zeros_like(z_shift), + # z_shift, + # ) - # nan_to_num equivalent - z_shift = torch.nan_to_num(z_shift) + # # nan_to_num equivalent + # z_shift = torch.nan_to_num(z_shift) - # defocus reshape (-1,1,1) - if isinstance(defocus, torch.Tensor): - defocus_t = defocus.to(device=device, dtype=real_dtype) - else: - defocus_t = torch.as_tensor( - defocus, device=device, dtype=real_dtype - ) + # torch.nan_to_num on complex tensors does not support autograd + # workaround: - defocus_t = defocus_t.reshape(-1, 1, 1) + inside = 1.0 - alpha * RHO.real + inside = torch.where( + inside >= 0, + inside, + torch.zeros_like(inside), + ) + z_shift = (k0 * vz) * torch.sqrt(inside).to(complex_dtype) + + # defocus reshape (-1,1,1) + defocus_t = torch.as_tensor( + defocus, + device=device, + dtype=real_dtype, + ).reshape(-1, 1, 1) # broadcast z_shift to (Z,H,W) z_shift_3d = defocus_t * z_shift.unsqueeze(0) @@ -1014,7 +1079,9 @@ def _pupil_torch( # move it to torch) elif isinstance(pupil_feat, np.ndarray): pf = torch.as_tensor( - pupil_feat, device=device, dtype=pupil_function.dtype + pupil_feat, + device=device, + dtype=pupil_function.dtype, ) pupil_function = pupil_function * pf @@ -1928,7 +1995,11 @@ def get( volume = pad_image_to_fft(padded_volume, axes=(0, 1)) - voxel_size = get_active_voxel_size() + voxel_size = kwargs.get("voxel_size", get_active_voxel_size()) + if self.get_backend() == "torch" and not torch.is_tensor( + voxel_size + ): + voxel_size = xp.asarray(voxel_size, dtype=xp.float64) pupils = [ self._pupil( diff --git a/deeptrack/optical/scatterers.py b/deeptrack/optical/scatterers.py index 4c23e8a85..aa471842e 100644 --- a/deeptrack/optical/scatterers.py +++ b/deeptrack/optical/scatterers.py @@ -448,12 +448,9 @@ def _process_and_get( Positional arguments passed to the method. Not used in this implementation. 
voxel_size: array - Voxel size supplied by the feature pipeline. In practice, - scatterers use the active optics configuration - (`get_active_voxel_size()`) to ensure that geometry evaluation is - consistent with the current imaging context. This argument is - considered framework-internal and is not intended as a user-facing - override. + Voxel size supplied by the feature pipeline. Field scatterers use + this value directly; volume scatterers use the active optics + context to keep geometry evaluation aligned with upsampling. upsample: int Geometry supersampling factor for volume-based scatterers. Ignored by field-based scatterers. @@ -482,7 +479,10 @@ def _process_and_get( + "Optics.upscale != 1." ) - voxel_size = xp.asarray(get_active_voxel_size(), dtype=float) + if isinstance(self, FieldScatterer) and voxel_size is not None: + voxel_size = _asarray(voxel_size, dtype=xp.float64) + else: + voxel_size = xp.asarray(get_active_voxel_size(), dtype=float) apply_supersampling = upsample > 1 and isinstance( self, VolumeScatterer @@ -1452,12 +1452,17 @@ def _process_properties( radius_for_l = properties["radius"] if TORCH_AVAILABLE and torch.is_tensor(radius_for_l): radius_for_l = radius_for_l.detach().cpu().numpy() + wavelength_for_l = properties["wavelength"] + if TORCH_AVAILABLE and torch.is_tensor(wavelength_for_l): + wavelength_for_l = ( + wavelength_for_l.detach().cpu().numpy() + ) v = ( 2 * np.pi * np.max(radius_for_l) - / properties["wavelength"] + / wavelength_for_l ) properties["L"] = int(np.floor((v + 4 * (v ** (1 / 3)) + 1))) @@ -1487,10 +1492,13 @@ def _process_properties( collection_angle, dtype=xp.float64, ) + voxel_size = properties.get("voxel_size") + if voxel_size is None: + voxel_size = get_active_voxel_size() properties["offset_z"] = ( min_edge_size * 0.45 - * min(get_active_voxel_size()[:2]) + * xp.min(_asarray(voxel_size, dtype=xp.float64)[:2]) / xp.tan(collection_angle) ) return properties @@ -1814,6 +1822,7 @@ def _mie_scattering( def 
_common_setup( self: MieScatterer, position: tuple[float, float, float], + voxel_size: np.ndarray, padding: tuple[int, int, int, int], output_region: tuple[int, int, int, int], wavelength: float, @@ -1837,6 +1846,8 @@ def _common_setup( ---------- position: tuple[float, float, float] The position of the particle in (x, y, z) coordinates. + voxel_size: np.ndarray + The physical voxel size in meters. padding: int The padding applied to the output region. output_region: tuple[int, int] @@ -1864,8 +1875,8 @@ def _common_setup( """ xSize, ySize = self.get_xy_size(output_region, padding) - voxel_size = xp.asarray( - get_active_voxel_size(), + voxel_size = _asarray( + voxel_size, dtype=xp.float64, ) scale = xp.asarray( @@ -2079,6 +2090,7 @@ def _solve_geometric( relative_position, ) = self._common_setup( position, + voxel_size, padding, output_region, wavelength, @@ -2300,6 +2312,7 @@ def _solve_hybrid( relative_position, ) = self._common_setup( position, + voxel_size, padding, output_region, wavelength, diff --git a/tests/test_scatterers.py b/tests/test_scatterers.py index 08ef9d7a4..adcbfd16c 100644 --- a/tests/test_scatterers.py +++ b/tests/test_scatterers.py @@ -3,6 +3,7 @@ # sys.path.append(".") # Adds the module to path import unittest +import warnings import numpy as np @@ -716,6 +717,90 @@ def test_mie_sphere_brightfield_sums_multiple_torch_fields(self): self.assertGreater(abs(float(radius_1.grad)), 0) self.assertGreater(abs(float(radius_2.grad)), 0) + def test_mie_sphere_brightfield_autodiff_learnable_parameters(self): + cases = [ + ("x", 14.25, "sample"), + ("y", 16.75, "sample"), + ("resolution", 1.0e-6, "optics"), + ("NA", 0.7, "optics"), + ("magnification", 10.0, "optics"), + ("wavelength", 680e-9, "optics"), + ("refractive_index_medium", 1.33, "optics"), + ] + + for name, value, owner in cases: + with self.subTest(parameter=name): + parameter = torch.tensor( + value, + dtype=torch.float64, + requires_grad=True, + ) + + sample_kwargs = dict( + radius=0.5e-6, 
+ refractive_index=1.45, + position=(14.25, 16.75), + position_unit="pixel", + input_polarization=0.0, + output_polarization=0.0, + L=5, + collection_angle=0.3, + offset_z=1e-5, + mode="hybrid", + ) + optics_kwargs = dict( + NA=0.7, + wavelength=680e-9, + refractive_index_medium=1.33, + resolution=1e-6, + magnification=10, + output_region=(0, 0, 32, 32), + padding=(4, 4, 4, 4), + return_field=True, + ) + + if owner == "sample" and name == "x": + sample_kwargs["position"] = (parameter, 16.75) + elif owner == "sample" and name == "y": + sample_kwargs["position"] = (14.25, parameter) + else: + optics_kwargs[name] = parameter + if name == "NA": + sample_kwargs.pop("collection_angle") + sample_kwargs.pop("offset_z") + + sample = scatterers.MieSphere(**sample_kwargs) + microscope = Brightfield(**optics_kwargs) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + image = microscope(sample).resolve() + + tensor_warning = ( + "Converting a tensor with requires_grad=True to a scalar" + ) + self.assertFalse( + any(tensor_warning in str(w.message) for w in caught) + ) + self.assertIsInstance(image, torch.Tensor) + self.assertTrue(image.requires_grad) + self.assertTrue(torch.isfinite(image.real).all()) + self.assertTrue(torch.isfinite(image.imag).all()) + + weights = torch.linspace( + 0.5, + 1.5, + image.numel(), + dtype=image.real.dtype, + device=image.device, + ).reshape(image.shape) + loss = (torch.abs(image) * weights).sum() + loss.backward() + + self.assertIsNotNone(parameter.grad) + self.assertTrue(torch.isfinite(parameter.grad)) + self.assertGreater(abs(float(parameter.grad)), 0) + def test_point_particle_intensity_gradient(self): # --- PointParticle intensity optimization --- From fa4761e738e01bd0a7ce59f357b87d4b6ed21e80 Mon Sep 17 00:00:00 2001 From: edudc Date: Thu, 28 May 2026 14:44:55 +0200 Subject: [PATCH 11/13] only bypass Pint for torch tensors that actually require grad --- deeptrack/backend/units.py | 11 +++++++++-- 1 
file changed, 9 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/units.py b/deeptrack/backend/units.py index 96d3f6e44..8e2a1fa12 100644 --- a/deeptrack/backend/units.py +++ b/deeptrack/backend/units.py @@ -386,7 +386,11 @@ def convert( default_unit, desired_unit = value - if TORCH_AVAILABLE and torch.is_tensor(quantity): + if ( + TORCH_AVAILABLE + and torch.is_tensor(quantity) + and quantity.requires_grad + ): factor = (1 * default_unit).to(desired_unit).to_reduced_units() factor = factor.magnitude kwargs[key] = quantity * factor @@ -395,7 +399,10 @@ def convert( if ( TORCH_AVAILABLE and isinstance(quantity, (list, tuple)) - and any(torch.is_tensor(item) for item in quantity) + and any( + torch.is_tensor(item) and item.requires_grad + for item in quantity + ) ): factor = (1 * default_unit).to(desired_unit).to_reduced_units() factor = factor.magnitude From e7a927a3cffa11c375384511e536b9110307f131 Mon Sep 17 00:00:00 2001 From: edudc Date: Thu, 28 May 2026 15:00:16 +0200 Subject: [PATCH 12/13] differentiable paths use explicit kwargs["voxel_size"], so we must scale that explicit value --- deeptrack/optical/optics.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/deeptrack/optical/optics.py b/deeptrack/optical/optics.py index 6aac98007..7e352865f 100644 --- a/deeptrack/optical/optics.py +++ b/deeptrack/optical/optics.py @@ -335,6 +335,16 @@ def get( upscale = np.round(get_active_scale()) + voxel_size = additional_sample_kwargs["voxel_size"] + if TORCH_AVAILABLE and torch.is_tensor(voxel_size): + additional_sample_kwargs["voxel_size"] = voxel_size / torch.as_tensor( + upscale, + device=voxel_size.device, + dtype=voxel_size.dtype, + ) + else: + additional_sample_kwargs["voxel_size"] = get_active_voxel_size() + output_region = additional_sample_kwargs.pop("output_region") additional_sample_kwargs["output_region"] = [ int(o * upsc) @@ -358,6 +368,9 @@ def get( self._objective.padding.set_value( additional_sample_kwargs["padding"] ) + 
self._objective.voxel_size.set_value( + additional_sample_kwargs["voxel_size"] + ) propagate_data_to_dependencies( self._sample, From 19579348627c4441cd230efbe9d97a25bcd6a2e3 Mon Sep 17 00:00:00 2001 From: edudc Date: Thu, 28 May 2026 15:11:01 +0200 Subject: [PATCH 13/13] tutorial for position optimization --- .../DTDV431_mie_position_optimization.ipynb | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 tutorials/4-developers/DTDV431_mie_position_optimization.ipynb diff --git a/tutorials/4-developers/DTDV431_mie_position_optimization.ipynb b/tutorials/4-developers/DTDV431_mie_position_optimization.ipynb new file mode 100644 index 000000000..aeed7bb71 --- /dev/null +++ b/tutorials/4-developers/DTDV431_mie_position_optimization.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mie position optimization with torch autodiff\n", + "\n", + "This notebook fits the `(x, y)` position of a single `MieSphere` by backpropagating through a `Brightfield` image. It is intentionally small and uses fixed `L`, `collection_angle`, and `offset_z` so the optimized variables are continuous.\n", + "\n", + "The important detail is calling `pipeline.update()()` inside the optimization loop. DeepTrack caches feature outputs, so `update()` is needed after each optimizer step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from deeptrack.backend import config\n", + "from deeptrack.optical.optics import Brightfield\n", + "from deeptrack.optical import scatterers\n", + "\n", + "torch.manual_seed(0)\n", + "torch.set_default_dtype(torch.float64)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "true_x, true_y = 15.5, 17.25\n", + "\n", + "with config.with_backend(\"torch\"):\n", + " optics = Brightfield(\n", + " NA=0.7,\n", + " wavelength=680e-9,\n", + " refractive_index_medium=1.33,\n", + " resolution=1e-6,\n", + " magnification=10,\n", + " output_region=(0, 0, 32, 32),\n", + " padding=(4, 4, 4, 4),\n", + " )\n", + "\n", + " mie_kwargs = dict(\n", + " radius=0.5e-6,\n", + " refractive_index=1.45,\n", + " position_unit=\"pixel\",\n", + " input_polarization=0.0,\n", + " output_polarization=0.0,\n", + " L=5,\n", + " collection_angle=0.3,\n", + " offset_z=1e-5,\n", + " mode=\"hybrid\",\n", + " )\n", + "\n", + " target_sample = scatterers.MieSphere(\n", + " position=(true_x, true_y),\n", + " **mie_kwargs,\n", + " )\n", + " target = optics(target_sample).update()().detach()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with config.with_backend(\"torch\"):\n", + " x = torch.tensor(13.0, requires_grad=True)\n", + " y = torch.tensor(14.5, requires_grad=True)\n", + "\n", + " fitted_sample = scatterers.MieSphere(\n", + " position=(x, y),\n", + " **mie_kwargs,\n", + " )\n", + " pipeline = optics(fitted_sample)\n", + " optimizer = torch.optim.Adam([x, y], lr=0.2)\n", + "\n", + " history = []\n", + " for step in range(50):\n", + " optimizer.zero_grad()\n", + " image = pipeline.update()()\n", + " loss = ((image - target) ** 2).mean()\n", + " loss.backward()\n", + " 
optimizer.step()\n", + "\n", + " history.append(\n", + " (step, loss.item(), x.detach().item(), y.detach().item())\n", + " )\n", + "\n", + "history = np.array(history)\n", + "print(f\"true position: ({true_x:.2f}, {true_y:.2f})\")\n", + "print(f\"fitted position: ({history[-1, 2]:.2f}, {history[-1, 3]:.2f})\")\n", + "print(f\"final loss: {history[-1, 1]:.3e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(9, 3.5))\n", + "\n", + "axes[0].semilogy(history[:, 0], history[:, 1])\n", + "axes[0].set_xlabel(\"step\")\n", + "axes[0].set_ylabel(\"MSE loss\")\n", + "\n", + "axes[1].plot(history[:, 2], history[:, 3], marker=\".\", label=\"fit\")\n", + "axes[1].scatter([true_x], [true_y], c=\"tab:red\", label=\"target\")\n", + "axes[1].set_xlabel(\"x [px]\")\n", + "axes[1].set_ylabel(\"y [px]\")\n", + "axes[1].axis(\"equal\")\n", + "axes[1].legend()\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fitted_image = pipeline.update()().detach()\n", + "residual = fitted_image - target\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(10, 3))\n", + "for ax, array, title in zip(\n", + " axes,\n", + " (target, fitted_image, residual),\n", + " (\"target\", \"fitted\", \"residual\"),\n", + "):\n", + " ax.imshow(array[..., 0].cpu().numpy(), cmap=\"gray\")\n", + " ax.set_title(title)\n", + " ax.axis(\"off\")\n", + "\n", + "fig.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}