@@ -149,20 +149,6 @@ def amplitude_fn(freq: list[float]) -> complex:
 
         return self.normalize(amplitude_fn)
 
-    def _updated(self, update: dict) -> MonitorData:
-        """Similar to ``updated_copy``, but does not actually copy components, for speed.
-
-        Note
-        ----
-        This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
-        are not updated, they will point to the values in the original data. This method should
-        thus be used carefully.
-
-        """
-        data_dict = self.dict()
-        data_dict.update(update)
-        return type(self).parse_obj(data_dict)
-
     def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]:
         """Generate adjoint sources for this ``MonitorData`` instance."""
 
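The caveat spelled out in the removed `_updated` docstring (untouched mutable fields still point at the original arrays) carries over to the `updated_copy(..., deep=False, validate=False)` calls that replace it below. A minimal standalone sketch of that aliasing behavior, using a hypothetical `ToyData` container rather than `MonitorData`:

```python
import numpy as np
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyData:
    # Hypothetical stand-in for a data container with a mutable array field.
    amps: np.ndarray
    label: str


original = ToyData(amps=np.zeros(3), label="raw")

# Shallow, non-copying update: only the listed fields are swapped out.
shallow = replace(original, label="normalized")

# The untouched array is shared, not copied, so in-place edits are visible
# through both objects; this is the caveat the removed docstring warns about.
assert shallow.amps is original.amps
```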
@@ -261,7 +247,7 @@ def symmetry_expanded(self):
         if all(sym == 0 for sym in self.symmetry):
             return self
 
-        return self._updated(self._symmetry_update_dict)
+        return self.updated_copy(**self._symmetry_update_dict, deep=False, validate=False)
 
     @property
     def symmetry_expanded_copy(self) -> AbstractFieldData:
@@ -780,21 +766,41 @@ def dot(
         fields_self = {key: field.conj() for key, field in fields_self.items()}
 
         fields_other = field_data._interpolated_tangential_fields(self._plane_grid_boundaries)
+        dim1, dim2 = self._tangential_dims
+        d_area = self._diff_area
 
-        # Drop size-1 dimensions in the other data
-        fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
+        try:
+            for field_self, field_other in zip(fields_self.values(), fields_other.values()):
+                for key in self._tangential_dims:
+                    if not np.all(field_self.coords[key] == field_other.coords[key]):
+                        raise ValueError("Coordinates do not match.")
+
+            # Tangential coordinates match, so we try to use .values for speed.
+            # This will work if other coordinates match dimensions or are broadcastable.
+            # This is OK as we do not enforce frequencies or mode indexes to be the same.
+            # If this fails it will fallback to the xarray handling below.
+            e_self_x_h_other = fields_self["E" + dim1].values * fields_other["H" + dim2].values
+            e_self_x_h_other -= fields_self["E" + dim2].values * fields_other["H" + dim1].values
+            h_self_x_e_other = fields_self["H" + dim1].values * fields_other["E" + dim2].values
+            h_self_x_e_other -= fields_self["H" + dim2].values * fields_other["E" + dim1].values
+            integrand = xr.DataArray(
+                e_self_x_h_other - h_self_x_e_other, coords=fields_self["E" + dim1].coords
+            )
+            integrand *= d_area
+        except Exception:
+            # Catching a broad exception here in case anything went wrong in the check.
 
-        # Cross products of fields
-        dim1, dim2 = self._tangential_dims
-        e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
-        e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
-        h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
-        h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
+            # Drop size-1 dimensions in the other data
+            fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
 
-        # Integrate over plane
-        d_area = self._diff_area
-        integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
+            # Cross products of fields
+            e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
+            e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
+            h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
+            h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
+            integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
 
+        # Integrate over plane
         return ModeAmpsDataArray(0.25 * integrand.sum(dim=d_area.dims))
 
     def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]:
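A self-contained sketch of the fast path introduced in the hunk above: when the in-plane coordinates of the two fields match exactly, the cross products can run on raw `.values` and be re-wrapped in a labelled array, with label-aware xarray arithmetic as the fallback. The grids, shapes, and field names here are made up for illustration:

```python
import numpy as np
import xarray as xr

xs, ys, freqs = np.linspace(0, 1, 4), np.linspace(0, 1, 3), [1.0, 2.0]
coords = {"x": xs, "y": ys, "f": freqs}
e_self = xr.DataArray(np.random.rand(4, 3, 2), coords=coords, dims=("x", "y", "f"))
h_other = xr.DataArray(np.random.rand(4, 3, 2), coords=coords, dims=("x", "y", "f"))

try:
    # Check that the in-plane coordinates are identical before bypassing xarray.
    for dim in ("x", "y"):
        if not np.all(e_self.coords[dim] == h_other.coords[dim]):
            raise ValueError("Coordinates do not match.")
    # Fast path: plain NumPy broadcasting on the underlying arrays.
    integrand = xr.DataArray(
        e_self.values * h_other.values, coords=e_self.coords, dims=e_self.dims
    )
except Exception:
    # Fallback: xarray aligns by coordinate labels, which is slower but safer.
    integrand = e_self * h_other
```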
@@ -811,6 +817,19 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]:
         """
         fields = self._tangential_fields
 
+        try:
+            # If coords already match, just return the tangential fields directly.
+            # Using try: except for flow control.
+            for field in fields.values():
+                for idim, dim in enumerate(self._tangential_dims):
+                    if field.coords[dim].values.size != coords[idim].size or not np.all(
+                        field.coords[dim].values == coords[idim]
+                    ):
+                        raise ValueError("Coordinates do not match.")
+            return fields
+        except ValueError:
+            pass
+
         # Interpolate if data has more than one coordinate along a dimension
         interp_dict = {"assume_sorted": True}
         # If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
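The early return added above skips interpolation whenever the stored grid already equals the requested plane coordinates. The same size-then-values check, sketched on a hypothetical 2D field:

```python
import numpy as np
import xarray as xr

xs, ys = np.linspace(-1, 1, 5), np.linspace(-1, 1, 4)
field = xr.DataArray(np.random.rand(5, 4), coords={"x": xs, "y": ys}, dims=("x", "y"))
target = (xs.copy(), ys.copy())  # requested plane coordinates

# Compare sizes first (cheap), then values; only interpolate when something differs.
same_grid = all(
    field.coords[dim].values.size == t.size and np.all(field.coords[dim].values == t)
    for dim, t in zip(("x", "y"), target)
)
result = field if same_grid else field.interp(x=target[0], y=target[1], assume_sorted=True)
```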
@@ -1711,10 +1730,12 @@ def overlap_sort(
 
         # Normalizing the flux to 1, does not guarantee self terms of overlap integrals
         # are also normalized to 1 when the non-conjugated product is used.
-        if self.monitor.conjugated_dot_product:
+        data_expanded = self.symmetry_expanded
+        if data_expanded.monitor.conjugated_dot_product:
             self_overlap = np.ones((num_freqs, num_modes))
         else:
-            self_overlap = np.abs(self.dot(self, self.monitor.conjugated_dot_product).values)
+            self_overlap = data_expanded.dot(data_expanded, self.monitor.conjugated_dot_product)
+            self_overlap = np.abs(self_overlap.values)
         threshold_array = overlap_thresh * self_overlap
 
         # Compute sorting order and overlaps with neighboring frequencies
@@ -1727,20 +1748,19 @@ def overlap_sort(
         # Sort in two directions from the base frequency
         for step, last_ind in zip([-1, 1], [-1, num_freqs]):
             # Start with the base frequency
-            data_template = self._isel(f=[f0_ind])
+            data_template = data_expanded._isel(f=[f0_ind])
 
             # March to lower/higher frequencies
             for freq_id in range(f0_ind + step, last_ind, step):
                 # Calculate threshold array for this frequency
-                if not self.monitor.conjugated_dot_product:
+                if not data_expanded.monitor.conjugated_dot_product:
                     overlap_thresh = threshold_array[freq_id, :]
                 # Get next frequency to sort
-                data_to_sort = self._isel(f=[freq_id])
+                data_to_sort = data_expanded._isel(f=[freq_id])
                 # Assign to the base frequency so that outer_dot will compare them
                 data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]])
 
                 # Compute "sorting w.r.t. to neighbor" and overlap values
-
                 sorting_one_mode, amps_one_mode = data_template._find_ordering_one_freq(
                     data_to_sort, overlap_thresh
                 )
@@ -1756,8 +1776,8 @@ def overlap_sort(
                 for mode_ind in list(np.nonzero(overlap[freq_id, :] < overlap_thresh)[0]):
                     log.warning(
                         f"Mode '{mode_ind}' appears to undergo a discontinuous change "
-                        f"between frequencies '{self.monitor.freqs[freq_id]}' "
-                        f"and '{self.monitor.freqs[freq_id - step]}' "
+                        f"between frequencies '{data_expanded.monitor.freqs[freq_id]}' "
+                        f"and '{data_expanded.monitor.freqs[freq_id - step]}' "
                         f"(overlap: '{overlap[freq_id, mode_ind]:.2f}')."
                     )
 
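For reference, the two `zip([-1, 1], [-1, num_freqs])` passes above march outward from the base frequency index in both directions. A small sketch of the index pattern that loop visits, with made-up values:

```python
num_freqs, f0_ind = 5, 2

visited = []
for step, last_ind in zip([-1, 1], [-1, num_freqs]):
    # step = -1 walks down to index 0; step = +1 walks up to num_freqs - 1.
    for freq_id in range(f0_ind + step, last_ind, step):
        visited.append(freq_id)

# Frequencies are visited outward from the base index, which itself is never revisited.
assert visited == [1, 0, 3, 4]
```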
@@ -1796,7 +1816,7 @@ def _isel(self, **isel_kwargs):
             for key, field in update_dict.items()
             if isinstance(field, DataArray)
         }
-        return self._updated(update=update_dict)
+        return self.updated_copy(**update_dict, deep=False, validate=False)
 
     def _assign_coords(self, **assign_coords_kwargs):
         """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
@@ -1808,7 +1828,7 @@ def _assign_coords(self, **assign_coords_kwargs):
         update_dict = {
             key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
         }
-        return self._updated(update=update_dict)
+        return self.updated_copy(**update_dict, deep=False, validate=False)
 
     def _find_ordering_one_freq(
         self,
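Both wrappers above lean on two xarray idioms: `isel` with a list (`f=[freq_id]`) keeps `f` as a length-1 dimension instead of dropping it, and `assign_coords` relabels that frequency so a later dot product sees matching coordinates. A toy illustration with made-up data:

```python
import numpy as np
import xarray as xr

freqs = np.array([1.0, 2.0, 3.0])
amps = xr.DataArray(
    np.random.rand(3, 2), coords={"f": freqs, "mode_index": [0, 1]}, dims=("f", "mode_index")
)

# Selecting with a list keeps the 'f' dimension, giving shape (1, 2).
one_freq = amps.isel(f=[1])

# Relabel to the base frequency so arrays from different frequencies align along 'f'.
relabelled = one_freq.assign_coords(f=[freqs[0]])
assert relabelled.coords["f"].values[0] == freqs[0]
```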
@@ -2214,20 +2234,58 @@ def _apply_mode_reorder(self, sort_inds_2d):
             Array of shape (num_freqs, num_modes) where each row is the
             permutation to apply to the mode_index for that frequency.
         """
+        sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
         num_freqs, num_modes = sort_inds_2d.shape
+
+        # Fast no-op
+        identity = np.arange(num_modes)
+        if np.all(sort_inds_2d == identity[None, :]):
+            return self
+
         modify_data = {}
+        new_mode_index_coord = identity
+
         for key, data in self.data_arrs.items():
             if "mode_index" not in data.dims or "f" not in data.dims:
                 continue
-            dims_orig = data.dims
-            f_coord = data.coords["f"]
-            slices = []
-            for ifreq in range(num_freqs):
-                sl = data.isel(f=ifreq, mode_index=sort_inds_2d[ifreq])
-                slices.append(sl.assign_coords(mode_index=np.arange(num_modes)))
-            # Concatenate along the 'f' dimension name and then restore original frequency coordinates
-            data = xr.concat(slices, dim="f").assign_coords(f=f_coord).transpose(*dims_orig)
-            modify_data[key] = data
+
+            dims_orig = tuple(data.dims)
+            # Preserve coords (as numpy)
+            coords_out = {
+                k: (v.values if hasattr(v, "values") else np.asarray(v))
+                for k, v in data.coords.items()
+            }
+            f_axis = data.get_axis_num("f")
+            m_axis = data.get_axis_num("mode_index")
+
+            # Move axes so array is (..., f, mode)
+            move_order = [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [
+                f_axis,
+                m_axis,
+            ]
+            arr = np.moveaxis(data.data, move_order, range(data.ndim))
+            lead_shape = arr.shape[:-2]
+            nf, nm = arr.shape[-2], arr.shape[-1]
+            if nf != num_freqs or nm != num_modes:
+                raise DataError(
+                    "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
+                )
+
+            # Vectorized gather: reshape to (nf, Nlead, nm), gather along last axis
+            arr3 = arr.reshape((-1, nf, nm)).transpose(1, 0, 2)  # (nf, Nlead, nm)
+            inds = sort_inds_2d[:, None, :]  # (nf, 1, nm)
+            arr3_sorted = np.take_along_axis(arr3, inds, axis=2)
+            arr_sorted = arr3_sorted.transpose(1, 0, 2).reshape(*lead_shape, nf, nm)
+
+            # Move axes back to original order
+            arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), move_order)
+
+            # Update coords: keep f, reset mode_index to 0..num_modes-1
+            coords_out["mode_index"] = new_mode_index_coord
+            coords_out["f"] = data.coords["f"].values
+
+            modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)
+
         return self.updated_copy(**modify_data)
 
     def sort_modes(
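The rewritten `_apply_mode_reorder` replaces the per-frequency `isel` loop with a single vectorized gather. A minimal sketch of the core `np.take_along_axis` step on a hypothetical `(num_freqs, num_modes)` array:

```python
import numpy as np

num_freqs, num_modes = 3, 4
amps = np.arange(num_freqs * num_modes).reshape(num_freqs, num_modes)

# One permutation of the mode axis per frequency (one row per frequency).
sort_inds_2d = np.array([[0, 1, 2, 3], [3, 2, 1, 0], [1, 0, 3, 2]])

# Gather along the mode axis in one call instead of looping over frequencies.
sorted_amps = np.take_along_axis(amps, sort_inds_2d, axis=1)
assert np.array_equal(sorted_amps[1], amps[1][[3, 2, 1, 0]])
```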