@@ -149,20 +149,6 @@ def amplitude_fn(freq: list[float]) -> complex:
149149
150150 return self .normalize (amplitude_fn )
151151
152- def _updated (self , update : dict ) -> MonitorData :
153- """Similar to ``updated_copy``, but does not actually copy components, for speed.
154-
155- Note
156- ----
157- This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
158- are not updated, they will point to the values in the original data. This method should
159- thus be used carefully.
160-
161- """
162- data_dict = self .dict ()
163- data_dict .update (update )
164- return type (self ).parse_obj (data_dict )
165-
166152 def _make_adjoint_sources (self , dataset_names : list [str ], fwidth : float ) -> list [Source ]:
167153 """Generate adjoint sources for this ``MonitorData`` instance."""
168154
@@ -261,7 +247,7 @@ def symmetry_expanded(self):
261247 if all (sym == 0 for sym in self .symmetry ):
262248 return self
263249
264- return self ._updated ( self ._symmetry_update_dict )
250+ return self .updated_copy ( ** self ._symmetry_update_dict , deep = False , validate = False )
265251
266252 @property
267253 def symmetry_expanded_copy (self ) -> AbstractFieldData :
@@ -780,21 +766,38 @@ def dot(
780766 fields_self = {key : field .conj () for key , field in fields_self .items ()}
781767
782768 fields_other = field_data ._interpolated_tangential_fields (self ._plane_grid_boundaries )
769+ dim1 , dim2 = self ._tangential_dims
770+ d_area = self ._diff_area
783771
784- # Drop size-1 dimensions in the other data
785- fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
772+ try :
773+ for field_self , field_other in zip (fields_self .values (), fields_other .values ()):
774+ for key , val in field_self .coords :
775+ if not np .all (val .values == field_other .coords [key ].values ):
776+ raise ValueError ("Coordinates do not match." )
777+
778+ # Coordinates match, so we can use .values for speed
779+ e_self_x_h_other = fields_self ["E" + dim1 ].values * fields_other ["H" + dim2 ].values
780+ e_self_x_h_other -= fields_self ["E" + dim2 ].values * fields_other ["H" + dim1 ].values
781+ h_self_x_e_other = fields_self ["H" + dim1 ].values * fields_other ["E" + dim2 ].values
782+ h_self_x_e_other -= fields_self ["H" + dim2 ].values * fields_other ["E" + dim1 ].values
783+ integrand = xr .DataArray (
784+ e_self_x_h_other - h_self_x_e_other , coords = fields_self ["E" + dim1 ].coords
785+ )
786+ integrand *= d_area
787+ except : # noqa: E722
788+ # Catching a broad exception here in case anything went wrong in the check.
786789
787- # Cross products of fields
788- dim1 , dim2 = self ._tangential_dims
789- e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
790- e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
791- h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
792- h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
790+ # Drop size-1 dimensions in the other data
791+ fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
793792
794- # Integrate over plane
795- d_area = self ._diff_area
796- integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
793+ # Cross products of fields
794+ e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
795+ e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
796+ h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
797+ h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
798+ integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
797799
800+ # Integrate over plane
798801 return ModeAmpsDataArray (0.25 * integrand .sum (dim = d_area .dims ))
799802
800803 def _interpolated_tangential_fields (self , coords : ArrayFloat2D ) -> dict [str , DataArray ]:
@@ -811,6 +814,19 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, Dat
811814 """
812815 fields = self ._tangential_fields
813816
817+ try :
818+ # If coords already match, just return the tangential fields directly
819+ for field in fields .values ():
820+ for idim , dim in enumerate (self ._tangential_dims ):
821+ if field .coords [dim ].values .size != coords [idim ].size or not np .all (
822+ field .coords [dim ].values == coords [idim ]
823+ ):
824+ raise ValueError ("Coordinates do not match." )
825+ return fields
826+ except : # noqa: E722
827+ # Catching a broad exception here in case anything went wrong in the check.
828+ pass
829+
814830 # Interpolate if data has more than one coordinate along a dimension
815831 interp_dict = {"assume_sorted" : True }
816832 # If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
@@ -1711,10 +1727,12 @@ def overlap_sort(
17111727
17121728 # Normalizing the flux to 1, does not guarantee self terms of overlap integrals
17131729 # are also normalized to 1 when the non-conjugated product is used.
1714- if self .monitor .conjugated_dot_product :
1730+ data_expanded = self .symmetry_expanded
1731+ if data_expanded .monitor .conjugated_dot_product :
17151732 self_overlap = np .ones ((num_freqs , num_modes ))
17161733 else :
1717- self_overlap = np .abs (self .dot (self , self .monitor .conjugated_dot_product ).values )
1734+ self_overlap = data_expanded .dot (data_expanded , self .monitor .conjugated_dot_product )
1735+ self_overlap = np .abs (self_overlap .values )
17181736 threshold_array = overlap_thresh * self_overlap
17191737
17201738 # Compute sorting order and overlaps with neighboring frequencies
@@ -1727,20 +1745,19 @@ def overlap_sort(
17271745 # Sort in two directions from the base frequency
17281746 for step , last_ind in zip ([- 1 , 1 ], [- 1 , num_freqs ]):
17291747 # Start with the base frequency
1730- data_template = self ._isel (f = [f0_ind ])
1748+ data_template = data_expanded ._isel (f = [f0_ind ])
17311749
17321750 # March to lower/higher frequencies
17331751 for freq_id in range (f0_ind + step , last_ind , step ):
17341752 # Calculate threshold array for this frequency
1735- if not self .monitor .conjugated_dot_product :
1753+ if not data_expanded .monitor .conjugated_dot_product :
17361754 overlap_thresh = threshold_array [freq_id , :]
17371755 # Get next frequency to sort
1738- data_to_sort = self ._isel (f = [freq_id ])
1756+ data_to_sort = data_expanded ._isel (f = [freq_id ])
17391757 # Assign to the base frequency so that outer_dot will compare them
17401758 data_to_sort = data_to_sort ._assign_coords (f = [self .monitor .freqs [f0_ind ]])
17411759
17421760 # Compute "sorting w.r.t. to neighbor" and overlap values
1743-
17441761 sorting_one_mode , amps_one_mode = data_template ._find_ordering_one_freq (
17451762 data_to_sort , overlap_thresh
17461763 )
@@ -1756,8 +1773,8 @@ def overlap_sort(
17561773 for mode_ind in list (np .nonzero (overlap [freq_id , :] < overlap_thresh )[0 ]):
17571774 log .warning (
17581775 f"Mode '{ mode_ind } ' appears to undergo a discontinuous change "
1759- f"between frequencies '{ self .monitor .freqs [freq_id ]} ' "
1760- f"and '{ self .monitor .freqs [freq_id - step ]} ' "
1776+ f"between frequencies '{ data_expanded .monitor .freqs [freq_id ]} ' "
1777+ f"and '{ data_expanded .monitor .freqs [freq_id - step ]} ' "
17611778 f"(overlap: '{ overlap [freq_id , mode_ind ]:.2f} ')."
17621779 )
17631780
@@ -1796,7 +1813,7 @@ def _isel(self, **isel_kwargs):
17961813 for key , field in update_dict .items ()
17971814 if isinstance (field , DataArray )
17981815 }
1799- return self ._updated ( update = update_dict )
1816+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18001817
18011818 def _assign_coords (self , ** assign_coords_kwargs ):
18021819 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
@@ -1808,7 +1825,7 @@ def _assign_coords(self, **assign_coords_kwargs):
18081825 update_dict = {
18091826 key : field .assign_coords (** assign_coords_kwargs ) for key , field in update_dict .items ()
18101827 }
1811- return self ._updated ( update = update_dict )
1828+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18121829
18131830 def _find_ordering_one_freq (
18141831 self ,
@@ -2214,20 +2231,58 @@ def _apply_mode_reorder(self, sort_inds_2d):
22142231 Array of shape (num_freqs, num_modes) where each row is the
22152232 permutation to apply to the mode_index for that frequency.
22162233 """
2234+ sort_inds_2d = np .asarray (sort_inds_2d , dtype = int )
22172235 num_freqs , num_modes = sort_inds_2d .shape
2236+
2237+ # Fast no-op
2238+ identity = np .arange (num_modes )
2239+ if np .all (sort_inds_2d == identity [None , :]):
2240+ return self
2241+
22182242 modify_data = {}
2243+ new_mode_index_coord = identity
2244+
22192245 for key , data in self .data_arrs .items ():
22202246 if "mode_index" not in data .dims or "f" not in data .dims :
22212247 continue
2222- dims_orig = data .dims
2223- f_coord = data .coords ["f" ]
2224- slices = []
2225- for ifreq in range (num_freqs ):
2226- sl = data .isel (f = ifreq , mode_index = sort_inds_2d [ifreq ])
2227- slices .append (sl .assign_coords (mode_index = np .arange (num_modes )))
2228- # Concatenate along the 'f' dimension name and then restore original frequency coordinates
2229- data = xr .concat (slices , dim = "f" ).assign_coords (f = f_coord ).transpose (* dims_orig )
2230- modify_data [key ] = data
2248+
2249+ dims_orig = tuple (data .dims )
2250+ # Preserve coords (as numpy)
2251+ coords_out = {
2252+ k : (v .values if hasattr (v , "values" ) else np .asarray (v ))
2253+ for k , v in data .coords .items ()
2254+ }
2255+ f_axis = data .get_axis_num ("f" )
2256+ m_axis = data .get_axis_num ("mode_index" )
2257+
2258+ # Move axes so array is (..., f, mode)
2259+ move_order = [ax for ax in range (data .ndim ) if ax not in (f_axis , m_axis )] + [
2260+ f_axis ,
2261+ m_axis ,
2262+ ]
2263+ arr = np .moveaxis (data .data , move_order , range (data .ndim ))
2264+ lead_shape = arr .shape [:- 2 ]
2265+ nf , nm = arr .shape [- 2 ], arr .shape [- 1 ]
2266+ if nf != num_freqs or nm != num_modes :
2267+ raise DataError (
2268+ "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
2269+ )
2270+
2271+ # Vectorized gather: reshape to (nf, Nlead, nm), gather along last axis
2272+ arr3 = arr .reshape ((- 1 , nf , nm )).transpose (1 , 0 , 2 ) # (nf, Nlead, nm)
2273+ inds = sort_inds_2d [:, None , :] # (nf, 1, nm)
2274+ arr3_sorted = np .take_along_axis (arr3 , inds , axis = 2 )
2275+ arr_sorted = arr3_sorted .transpose (1 , 0 , 2 ).reshape (* lead_shape , nf , nm )
2276+
2277+ # Move axes back to original order
2278+ arr_sorted = np .moveaxis (arr_sorted , range (data .ndim ), move_order )
2279+
2280+ # Update coords: keep f, reset mode_index to 0..num_modes-1
2281+ coords_out ["mode_index" ] = new_mode_index_coord
2282+ coords_out ["f" ] = data .coords ["f" ].values
2283+
2284+ modify_data [key ] = DataArray (arr_sorted , coords = coords_out , dims = dims_orig )
2285+
22312286 return self .updated_copy (** modify_data )
22322287
22332288 def sort_modes (
0 commit comments