Skip to content

Commit 2412a4e

Browse files
jobovy and claude committed
Address PR review: improve tests, simplify unit handling, optimize vectorization
- Tighten mockFlatWeaklyTDNonaxiM3 Jacobi tolerance to -6.0
- Replace isfinite-only tests with value checks: below/above grid tests compare against static snapshots and monopole decay; orbit tests check energy conservation or C/Python consistency
- Replace test_time_dependent_nonaxi_c_orbit_limited_grid and test_time_dependent_nonaxi_c_dxdv with Liouville phase-space volume conservation test (det(Jacobian) == 1)
- Add energy dissipation check to dynamical friction test
- Move test_time_dependent_quantity_density_warning to test_quantity.py
- Simplify _parse_density: use _density_has_units instead of duplicate try/except blocks for unit detection
- Optimize _compute_rho_lm_timedep: share broadcasting arrays between vectorization check and actual computation
- Add dx broadcasting shape comment in _quintic_hermite_ppoly_coeffs
- Add thread-safety note to P_buf/dP_buf Legendre buffer comment in C

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 37f599d commit 2412a4e

5 files changed

Lines changed: 235 additions & 140 deletions

File tree

galpy/potential/MultipoleExpansionPotential.py

Lines changed: 43 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -565,59 +565,43 @@ def _parse_density(dens, ro, vo):
565565
except TypeError:
566566
numOfParam = 1
567567
# Handle astropy units
568-
if has_t and _APY_LOADED:
569-
# Check if time-dependent density returns Quantity and warn
570-
param = [1.0] * numOfParam
571-
try:
572-
dens(*param, t=0.0).to(units.kg / units.m**3)
573-
except (AttributeError, units.UnitConversionError, TypeError):
574-
pass
575-
else:
576-
import warnings
568+
if has_t and MultipoleExpansionPotential._density_has_units(dens):
569+
import warnings
577570

578-
from ..util import galpyWarning
571+
from ..util import galpyWarning
579572

580-
warnings.warn(
581-
"Time-dependent density appears to return an astropy "
582-
"Quantity. Unit conversion is not supported for "
583-
"time-dependent densities; pass the density in internal "
584-
"units (1/ro^3 * vo^2 / (4 pi G)) instead.",
585-
galpyWarning,
573+
warnings.warn(
574+
"Time-dependent density appears to return an astropy "
575+
"Quantity. Unit conversion is not supported for "
576+
"time-dependent densities; pass the density in internal "
577+
"units (1/ro^3 * vo^2 / (4 pi G)) instead.",
578+
galpyWarning,
579+
)
580+
if not has_t and MultipoleExpansionPotential._density_has_units(dens):
581+
raw_dens = dens
582+
if numOfParam == 1:
583+
return (
584+
lambda R, z, phi: conversion.parse_dens(
585+
raw_dens(numpy.sqrt(R**2 + z**2)),
586+
ro=ro,
587+
vo=vo,
588+
),
589+
False,
590+
)
591+
elif numOfParam == 2:
592+
return (
593+
lambda R, z, phi: conversion.parse_dens(
594+
raw_dens(R, z), ro=ro, vo=vo
595+
),
596+
False,
586597
)
587-
if not has_t and _APY_LOADED:
588-
param = [1.0] * numOfParam
589-
_dens_unit_output = False
590-
try:
591-
dens(*param).to(units.kg / units.m**3)
592-
except (AttributeError, units.UnitConversionError):
593-
pass
594598
else:
595-
_dens_unit_output = True
596-
if _dens_unit_output:
597-
raw_dens = dens
598-
if numOfParam == 1:
599-
return (
600-
lambda R, z, phi: conversion.parse_dens(
601-
raw_dens(numpy.sqrt(R**2 + z**2)),
602-
ro=ro,
603-
vo=vo,
604-
),
605-
False,
606-
)
607-
elif numOfParam == 2:
608-
return (
609-
lambda R, z, phi: conversion.parse_dens(
610-
raw_dens(R, z), ro=ro, vo=vo
611-
),
612-
False,
613-
)
614-
else:
615-
return (
616-
lambda R, z, phi: conversion.parse_dens(
617-
raw_dens(R, z, phi), ro=ro, vo=vo
618-
),
619-
False,
620-
)
599+
return (
600+
lambda R, z, phi: conversion.parse_dens(
601+
raw_dens(R, z, phi), ro=ro, vo=vo
602+
),
603+
False,
604+
)
621605
# Wrap based on number of spatial params
622606
if has_t:
623607
if numOfParam == 1:
@@ -812,15 +796,15 @@ def _compute_rho_lm_timedep(
812796
# Axisymmetric: no phi integral needed
813797
rho_cos_all = numpy.zeros((Nt, Nr, L, 1))
814798
rho_sin_all = numpy.zeros((Nt, Nr, L, 1))
799+
# Preallocate broadcasting arrays for vectorized path
800+
R_col = rgrid[:, numpy.newaxis] # (Nr, 1)
801+
t_row = tgrid[numpy.newaxis, :] # (1, Nt)
815802
# Try fully vectorized: evaluate density at all (r, t) at once
816803
_vectorized = True
817804
try:
818-
R_2d = rgrid[:, numpy.newaxis] # (Nr, 1)
819-
z_2d = numpy.zeros((Nr, 1))
820-
t_2d = tgrid[numpy.newaxis, :] # (1, Nt)
821805
ct = ct_nodes[0]
822806
sintheta = numpy.sqrt(1.0 - ct**2)
823-
test = dens_func(R_2d * sintheta, R_2d * ct, 0.0, t_2d)
807+
test = dens_func(R_col * sintheta, R_col * ct, 0.0, t_row)
824808
if numpy.shape(test) != (Nr, Nt):
825809
_vectorized = False
826810
except (TypeError, ValueError):
@@ -829,14 +813,13 @@ def _compute_rho_lm_timedep(
829813
ct = ct_nodes[ict]
830814
wt = ct_weights[ict]
831815
sintheta = numpy.sqrt(1.0 - ct**2)
832-
R_col = rgrid[:, numpy.newaxis] # (Nr, 1)
833816
if _vectorized:
834817
# (Nr, Nt) via broadcasting
835818
rho_spatial = dens_func(
836819
R_col * sintheta,
837820
R_col * ct,
838821
0.0,
839-
tgrid[numpy.newaxis, :],
822+
t_row,
840823
).T # -> (Nt, Nr)
841824
else:
842825
rho_spatial = numpy.zeros((Nt, Nr))
@@ -860,24 +843,22 @@ def _compute_rho_lm_timedep(
860843
sin_mphi = numpy.sin(numpy.outer(phi_nodes, m_arr)) # (phi_order, M)
861844
rho_cos_all = numpy.zeros((Nt, Nr, L, M))
862845
rho_sin_all = numpy.zeros((Nt, Nr, L, M))
846+
# Preallocate broadcasting arrays for vectorized path
847+
R_3d = rgrid[:, numpy.newaxis, numpy.newaxis] # (Nr, 1, 1)
848+
t_3d = tgrid[numpy.newaxis, :, numpy.newaxis] # (1, Nt, 1)
849+
phi_3d = phi_nodes[numpy.newaxis, numpy.newaxis, :] # (1, 1, phi_order)
863850
# Try fully vectorized: evaluate density at all (r, t, phi) at once
864851
# per theta node. Shape: (Nr, Nt, phi_order)
865852
_vectorized = True
866853
try:
867854
ct = ct_nodes[0]
868855
sintheta = numpy.sqrt(1.0 - ct**2)
869-
R_3d = rgrid[:, numpy.newaxis, numpy.newaxis] # (Nr, 1, 1)
870-
t_3d = tgrid[numpy.newaxis, :, numpy.newaxis] # (1, Nt, 1)
871-
phi_3d = phi_nodes[numpy.newaxis, numpy.newaxis, :] # (1, 1, phi_order)
872856
test = dens_func(R_3d * sintheta, R_3d * ct, phi_3d, t_3d)
873857
if numpy.shape(test) != (Nr, Nt, phi_order):
874858
_vectorized = False
875859
except (TypeError, ValueError):
876860
_vectorized = False
877861
if _vectorized:
878-
R_3d = rgrid[:, numpy.newaxis, numpy.newaxis]
879-
t_3d = tgrid[numpy.newaxis, :, numpy.newaxis]
880-
phi_3d = phi_nodes[numpy.newaxis, numpy.newaxis, :]
881862
for ict in range(costheta_order):
882863
ct = ct_nodes[ict]
883864
wt = ct_weights[ict]
@@ -973,7 +954,7 @@ def _quintic_hermite_ppoly_coeffs(vals, derivs, derivs2, dx):
973954
fp_R = derivs[..., 1:]
974955
fpp_L = derivs2[..., :-1]
975956
fpp_R = derivs2[..., 1:]
976-
h = dx # (Nr-1,)
957+
h = dx # (Nr-1,); broadcasts with (..., Nr-1) batch dims via numpy rules
977958
# Bernstein coefficients for quintic (degree 5) Hermite interpolant
978959
b = numpy.empty(f_L.shape[:-1] + (6,) + f_L.shape[-1:])
979960
b[..., 0, :] = f_L

galpy/potential/potential_c_ext/MultipoleExpansionPotential.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ struct multipole_data {
6464

6565
double *rho_scratch; // Nr doubles, scratch for rho reconstruction
6666

67-
// Preallocated Legendre buffers (avoids malloc/free per evaluation)
67+
// Preallocated Legendre buffers (allocated once during init, not per-call).
68+
// NOTE: these are shared across calls, so they are NOT thread-safe if
69+
// multiple threads evaluate this potential concurrently with different
70+
// (R, z) coordinates (costheta would differ, causing a race condition).
6871
double *P_buf; // Psize doubles for P_l^m
6972
double *dP_buf; // Psize doubles for dP_l^m/d(costheta)
7073

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def pytest_generate_tests(metafunc):
135135
] = -4.0 # time-dependent, C integration
136136
jactol[
137137
"mockFlatWeaklyTDNonaxiM3MultipoleExpansionPotential"
138-
] = -4.0 # time-dependent non-axi M=3, C integration
138+
] = -6.0 # time-dependent non-axi M=3, C integration
139139
# Now generate all inputs and run tests
140140
tols = [tol[p] if p in tol else tol["default"] for p in pots]
141141
jactols = [jactol[p] if p in jactol else tol["default"] for p in pots]

0 commit comments

Comments
 (0)