@@ -565,59 +565,43 @@ def _parse_density(dens, ro, vo):
565565 except TypeError :
566566 numOfParam = 1
567567 # Handle astropy units
568- if has_t and _APY_LOADED :
569- # Check if time-dependent density returns Quantity and warn
570- param = [1.0 ] * numOfParam
571- try :
572- dens (* param , t = 0.0 ).to (units .kg / units .m ** 3 )
573- except (AttributeError , units .UnitConversionError , TypeError ):
574- pass
575- else :
576- import warnings
568+ if has_t and MultipoleExpansionPotential ._density_has_units (dens ):
569+ import warnings
577570
578- from ..util import galpyWarning
571+ from ..util import galpyWarning
579572
580- warnings .warn (
581- "Time-dependent density appears to return an astropy "
582- "Quantity. Unit conversion is not supported for "
583- "time-dependent densities; pass the density in internal "
584- "units (1/ro^3 * vo^2 / (4 pi G)) instead." ,
585- galpyWarning ,
573+ warnings .warn (
574+ "Time-dependent density appears to return an astropy "
575+ "Quantity. Unit conversion is not supported for "
576+ "time-dependent densities; pass the density in internal "
577+ "units (1/ro^3 * vo^2 / (4 pi G)) instead." ,
578+ galpyWarning ,
579+ )
580+ if not has_t and MultipoleExpansionPotential ._density_has_units (dens ):
581+ raw_dens = dens
582+ if numOfParam == 1 :
583+ return (
584+ lambda R , z , phi : conversion .parse_dens (
585+ raw_dens (numpy .sqrt (R ** 2 + z ** 2 )),
586+ ro = ro ,
587+ vo = vo ,
588+ ),
589+ False ,
590+ )
591+ elif numOfParam == 2 :
592+ return (
593+ lambda R , z , phi : conversion .parse_dens (
594+ raw_dens (R , z ), ro = ro , vo = vo
595+ ),
596+ False ,
586597 )
587- if not has_t and _APY_LOADED :
588- param = [1.0 ] * numOfParam
589- _dens_unit_output = False
590- try :
591- dens (* param ).to (units .kg / units .m ** 3 )
592- except (AttributeError , units .UnitConversionError ):
593- pass
594598 else :
595- _dens_unit_output = True
596- if _dens_unit_output :
597- raw_dens = dens
598- if numOfParam == 1 :
599- return (
600- lambda R , z , phi : conversion .parse_dens (
601- raw_dens (numpy .sqrt (R ** 2 + z ** 2 )),
602- ro = ro ,
603- vo = vo ,
604- ),
605- False ,
606- )
607- elif numOfParam == 2 :
608- return (
609- lambda R , z , phi : conversion .parse_dens (
610- raw_dens (R , z ), ro = ro , vo = vo
611- ),
612- False ,
613- )
614- else :
615- return (
616- lambda R , z , phi : conversion .parse_dens (
617- raw_dens (R , z , phi ), ro = ro , vo = vo
618- ),
619- False ,
620- )
599+ return (
600+ lambda R , z , phi : conversion .parse_dens (
601+ raw_dens (R , z , phi ), ro = ro , vo = vo
602+ ),
603+ False ,
604+ )
621605 # Wrap based on number of spatial params
622606 if has_t :
623607 if numOfParam == 1 :
@@ -812,15 +796,15 @@ def _compute_rho_lm_timedep(
812796 # Axisymmetric: no phi integral needed
813797 rho_cos_all = numpy .zeros ((Nt , Nr , L , 1 ))
814798 rho_sin_all = numpy .zeros ((Nt , Nr , L , 1 ))
799+ # Preallocate broadcasting arrays for vectorized path
800+ R_col = rgrid [:, numpy .newaxis ] # (Nr, 1)
801+ t_row = tgrid [numpy .newaxis , :] # (1, Nt)
815802 # Try fully vectorized: evaluate density at all (r, t) at once
816803 _vectorized = True
817804 try :
818- R_2d = rgrid [:, numpy .newaxis ] # (Nr, 1)
819- z_2d = numpy .zeros ((Nr , 1 ))
820- t_2d = tgrid [numpy .newaxis , :] # (1, Nt)
821805 ct = ct_nodes [0 ]
822806 sintheta = numpy .sqrt (1.0 - ct ** 2 )
823- test = dens_func (R_2d * sintheta , R_2d * ct , 0.0 , t_2d )
807+ test = dens_func (R_col * sintheta , R_col * ct , 0.0 , t_row )
824808 if numpy .shape (test ) != (Nr , Nt ):
825809 _vectorized = False
826810 except (TypeError , ValueError ):
@@ -829,14 +813,13 @@ def _compute_rho_lm_timedep(
829813 ct = ct_nodes [ict ]
830814 wt = ct_weights [ict ]
831815 sintheta = numpy .sqrt (1.0 - ct ** 2 )
832- R_col = rgrid [:, numpy .newaxis ] # (Nr, 1)
833816 if _vectorized :
834817 # (Nr, Nt) via broadcasting
835818 rho_spatial = dens_func (
836819 R_col * sintheta ,
837820 R_col * ct ,
838821 0.0 ,
839- tgrid [ numpy . newaxis , :] ,
822+ t_row ,
840823 ).T # -> (Nt, Nr)
841824 else :
842825 rho_spatial = numpy .zeros ((Nt , Nr ))
@@ -860,24 +843,22 @@ def _compute_rho_lm_timedep(
860843 sin_mphi = numpy .sin (numpy .outer (phi_nodes , m_arr )) # (phi_order, M)
861844 rho_cos_all = numpy .zeros ((Nt , Nr , L , M ))
862845 rho_sin_all = numpy .zeros ((Nt , Nr , L , M ))
846+ # Preallocate broadcasting arrays for vectorized path
847+ R_3d = rgrid [:, numpy .newaxis , numpy .newaxis ] # (Nr, 1, 1)
848+ t_3d = tgrid [numpy .newaxis , :, numpy .newaxis ] # (1, Nt, 1)
849+ phi_3d = phi_nodes [numpy .newaxis , numpy .newaxis , :] # (1, 1, phi_order)
863850 # Try fully vectorized: evaluate density at all (r, t, phi) at once
864851 # per theta node. Shape: (Nr, Nt, phi_order)
865852 _vectorized = True
866853 try :
867854 ct = ct_nodes [0 ]
868855 sintheta = numpy .sqrt (1.0 - ct ** 2 )
869- R_3d = rgrid [:, numpy .newaxis , numpy .newaxis ] # (Nr, 1, 1)
870- t_3d = tgrid [numpy .newaxis , :, numpy .newaxis ] # (1, Nt, 1)
871- phi_3d = phi_nodes [numpy .newaxis , numpy .newaxis , :] # (1, 1, phi_order)
872856 test = dens_func (R_3d * sintheta , R_3d * ct , phi_3d , t_3d )
873857 if numpy .shape (test ) != (Nr , Nt , phi_order ):
874858 _vectorized = False
875859 except (TypeError , ValueError ):
876860 _vectorized = False
877861 if _vectorized :
878- R_3d = rgrid [:, numpy .newaxis , numpy .newaxis ]
879- t_3d = tgrid [numpy .newaxis , :, numpy .newaxis ]
880- phi_3d = phi_nodes [numpy .newaxis , numpy .newaxis , :]
881862 for ict in range (costheta_order ):
882863 ct = ct_nodes [ict ]
883864 wt = ct_weights [ict ]
@@ -973,7 +954,7 @@ def _quintic_hermite_ppoly_coeffs(vals, derivs, derivs2, dx):
973954 fp_R = derivs [..., 1 :]
974955 fpp_L = derivs2 [..., :- 1 ]
975956 fpp_R = derivs2 [..., 1 :]
976- h = dx # (Nr-1,)
957+ h = dx # (Nr-1,); broadcasts with (..., Nr-1) batch dims via numpy rules
977958 # Bernstein coefficients for quintic (degree 5) Hermite interpolant
978959 b = numpy .empty (f_L .shape [:- 1 ] + (6 ,) + f_L .shape [- 1 :])
979960 b [..., 0 , :] = f_L
0 commit comments