@@ -149,20 +149,6 @@ def amplitude_fn(freq: list[float]) -> complex:
149149
150150 return self .normalize (amplitude_fn )
151151
152- def _updated (self , update : dict ) -> MonitorData :
153- """Similar to ``updated_copy``, but does not actually copy components, for speed.
154-
155- Note
156- ----
157- This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
158- are not updated, they will point to the values in the original data. This method should
159- thus be used carefully.
160-
161- """
162- data_dict = self .dict ()
163- data_dict .update (update )
164- return type (self ).parse_obj (data_dict )
165-
166152 def _make_adjoint_sources (self , dataset_names : list [str ], fwidth : float ) -> list [Source ]:
167153 """Generate adjoint sources for this ``MonitorData`` instance."""
168154
@@ -261,7 +247,7 @@ def symmetry_expanded(self):
261247 if all (sym == 0 for sym in self .symmetry ):
262248 return self
263249
264- return self ._updated ( self ._symmetry_update_dict )
250+ return self .updated_copy ( ** self ._symmetry_update_dict , deep = False , validate = False )
265251
266252 @property
267253 def symmetry_expanded_copy (self ) -> AbstractFieldData :
@@ -780,23 +766,51 @@ def dot(
780766 fields_self = {key : field .conj () for key , field in fields_self .items ()}
781767
782768 fields_other = field_data ._interpolated_tangential_fields (self ._plane_grid_boundaries )
769+ dim1 , dim2 = self ._tangential_dims
770+ d_area = self ._diff_area
783771
784- # Drop size-1 dimensions in the other data
785- fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
772+ # After interpolation, the tangential coordinates should match. However, the two arrays
773+ # may either have the same shape along other dimensions, or be broadcastable.
774+ if (
775+ fields_self [next (iter (fields_self ))].shape
776+ == fields_other [next (iter (fields_other ))].shape
777+ ):
778+ # Arrays are same shape, so we can use numpy
779+ e_self_x_h_other = fields_self ["E" + dim1 ].values * fields_other ["H" + dim2 ].values
780+ e_self_x_h_other -= fields_self ["E" + dim2 ].values * fields_other ["H" + dim1 ].values
781+ h_self_x_e_other = fields_self ["H" + dim1 ].values * fields_other ["E" + dim2 ].values
782+ h_self_x_e_other -= fields_self ["H" + dim2 ].values * fields_other ["E" + dim1 ].values
783+ integrand = xr .DataArray (
784+ e_self_x_h_other - h_self_x_e_other , coords = fields_self ["E" + dim1 ].coords
785+ )
786+ integrand *= d_area
787+ else :
788+ # Broadcasting is needed, which may be complicated depending on the dimensions order.
789+ # Use xarray to handle robustly.
786790
787- # Cross products of fields
788- dim1 , dim2 = self ._tangential_dims
789- e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
790- e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
791- h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
792- h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
791+ # Drop size-1 dimensions in the other data
792+ fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
793793
794- # Integrate over plane
795- d_area = self ._diff_area
796- integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
794+ # Cross products of fields
795+ e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
796+ e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
797+ h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
798+ h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
799+ integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
797800
801+ # Integrate over plane
798802 return ModeAmpsDataArray (0.25 * integrand .sum (dim = d_area .dims ))
799803
804+ def _tangential_fields_match_coords (self , coords : ArrayFloat2D ) -> bool :
805+ """Check if the tangential fields already match given coords in the tangential plane."""
806+ for field in self ._tangential_fields .values ():
807+ for idim , dim in enumerate (self ._tangential_dims ):
808+ if field .coords [dim ].values .size != coords [idim ].size or not np .all (
809+ field .coords [dim ].values == coords [idim ]
810+ ):
811+ return False
812+ return True
813+
800814 def _interpolated_tangential_fields (self , coords : ArrayFloat2D ) -> dict [str , DataArray ]:
801815 """For 2D monitors, interpolate this fields to given coords in the tangential plane.
802816
@@ -811,6 +825,10 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, Dat
811825 """
812826 fields = self ._tangential_fields
813827
828+ # If coords already match, just return the tangential fields directly.
829+ if self ._tangential_fields_match_coords (coords ):
830+ return fields
831+
814832 # Interpolate if data has more than one coordinate along a dimension
815833 interp_dict = {"assume_sorted" : True }
816834 # If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
@@ -1711,10 +1729,12 @@ def overlap_sort(
17111729
17121730 # Normalizing the flux to 1, does not guarantee self terms of overlap integrals
17131731 # are also normalized to 1 when the non-conjugated product is used.
1714- if self .monitor .conjugated_dot_product :
1732+ data_expanded = self .symmetry_expanded
1733+ if data_expanded .monitor .conjugated_dot_product :
17151734 self_overlap = np .ones ((num_freqs , num_modes ))
17161735 else :
1717- self_overlap = np .abs (self .dot (self , self .monitor .conjugated_dot_product ).values )
1736+ self_overlap = data_expanded .dot (data_expanded , self .monitor .conjugated_dot_product )
1737+ self_overlap = np .abs (self_overlap .values )
17181738 threshold_array = overlap_thresh * self_overlap
17191739
17201740 # Compute sorting order and overlaps with neighboring frequencies
@@ -1727,20 +1747,19 @@ def overlap_sort(
17271747 # Sort in two directions from the base frequency
17281748 for step , last_ind in zip ([- 1 , 1 ], [- 1 , num_freqs ]):
17291749 # Start with the base frequency
1730- data_template = self ._isel (f = [f0_ind ])
1750+ data_template = data_expanded ._isel (f = [f0_ind ])
17311751
17321752 # March to lower/higher frequencies
17331753 for freq_id in range (f0_ind + step , last_ind , step ):
17341754 # Calculate threshold array for this frequency
1735- if not self .monitor .conjugated_dot_product :
1755+ if not data_expanded .monitor .conjugated_dot_product :
17361756 overlap_thresh = threshold_array [freq_id , :]
17371757 # Get next frequency to sort
1738- data_to_sort = self ._isel (f = [freq_id ])
1758+ data_to_sort = data_expanded ._isel (f = [freq_id ])
17391759 # Assign to the base frequency so that outer_dot will compare them
17401760 data_to_sort = data_to_sort ._assign_coords (f = [self .monitor .freqs [f0_ind ]])
17411761
17421762 # Compute "sorting w.r.t. to neighbor" and overlap values
1743-
17441763 sorting_one_mode , amps_one_mode = data_template ._find_ordering_one_freq (
17451764 data_to_sort , overlap_thresh
17461765 )
@@ -1756,8 +1775,8 @@ def overlap_sort(
17561775 for mode_ind in list (np .nonzero (overlap [freq_id , :] < overlap_thresh )[0 ]):
17571776 log .warning (
17581777 f"Mode '{ mode_ind } ' appears to undergo a discontinuous change "
1759- f"between frequencies '{ self .monitor .freqs [freq_id ]} ' "
1760- f"and '{ self .monitor .freqs [freq_id - step ]} ' "
1778+ f"between frequencies '{ data_expanded .monitor .freqs [freq_id ]} ' "
1779+ f"and '{ data_expanded .monitor .freqs [freq_id - step ]} ' "
17611780 f"(overlap: '{ overlap [freq_id , mode_ind ]:.2f} ')."
17621781 )
17631782
@@ -1796,7 +1815,7 @@ def _isel(self, **isel_kwargs):
17961815 for key , field in update_dict .items ()
17971816 if isinstance (field , DataArray )
17981817 }
1799- return self ._updated ( update = update_dict )
1818+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18001819
18011820 def _assign_coords (self , ** assign_coords_kwargs ):
18021821 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
@@ -1808,7 +1827,7 @@ def _assign_coords(self, **assign_coords_kwargs):
18081827 update_dict = {
18091828 key : field .assign_coords (** assign_coords_kwargs ) for key , field in update_dict .items ()
18101829 }
1811- return self ._updated ( update = update_dict )
1830+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18121831
18131832 def _find_ordering_one_freq (
18141833 self ,
@@ -2214,20 +2233,56 @@ def _apply_mode_reorder(self, sort_inds_2d):
22142233 Array of shape (num_freqs, num_modes) where each row is the
22152234 permutation to apply to the mode_index for that frequency.
22162235 """
2236+ sort_inds_2d = np .asarray (sort_inds_2d , dtype = int )
22172237 num_freqs , num_modes = sort_inds_2d .shape
2238+
2239+ # Fast no-op
2240+ identity = np .arange (num_modes )
2241+ if np .all (sort_inds_2d == identity [None , :]):
2242+ return self
2243+
22182244 modify_data = {}
2245+ new_mode_index_coord = identity
2246+
22192247 for key , data in self .data_arrs .items ():
22202248 if "mode_index" not in data .dims or "f" not in data .dims :
22212249 continue
2222- dims_orig = data .dims
2223- f_coord = data .coords ["f" ]
2224- slices = []
2225- for ifreq in range (num_freqs ):
2226- sl = data .isel (f = ifreq , mode_index = sort_inds_2d [ifreq ])
2227- slices .append (sl .assign_coords (mode_index = np .arange (num_modes )))
2228- # Concatenate along the 'f' dimension name and then restore original frequency coordinates
2229- data = xr .concat (slices , dim = "f" ).assign_coords (f = f_coord ).transpose (* dims_orig )
2230- modify_data [key ] = data
2250+
2251+ dims_orig = tuple (data .dims )
2252+ # Preserve coords (as numpy)
2253+ coords_out = {
2254+ k : (v .values if hasattr (v , "values" ) else np .asarray (v ))
2255+ for k , v in data .coords .items ()
2256+ }
2257+ f_axis = data .get_axis_num ("f" )
2258+ m_axis = data .get_axis_num ("mode_index" )
2259+
2260+ # Move axes directly to (f, ..., mode)
2261+ src_order = (
2262+ [f_axis ] + [ax for ax in range (data .ndim ) if ax not in (f_axis , m_axis )] + [m_axis ]
2263+ )
2264+ arr = np .moveaxis (data .data , src_order , range (data .ndim ))
2265+ nf , nm = arr .shape [0 ], arr .shape [- 1 ]
2266+ if nf != num_freqs or nm != num_modes :
2267+ raise DataError (
2268+ "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
2269+ )
2270+
2271+ # Apply sorting
2272+ arr2 = arr .reshape (nf , - 1 , nm ) # (nf, Nlead, nm)
2273+ inds = sort_inds_2d [:, None , :] # (nf, 1, nm)
2274+ arr2_sorted = np .take_along_axis (arr2 , inds , axis = 2 )
2275+ arr_sorted = arr2_sorted .reshape (arr .shape )
2276+
2277+ # Move axes back to original order
2278+ arr_sorted = np .moveaxis (arr_sorted , range (data .ndim ), src_order )
2279+
2280+ # Update coords: keep f, reset mode_index to 0..num_modes-1
2281+ coords_out ["mode_index" ] = new_mode_index_coord
2282+ coords_out ["f" ] = data .coords ["f" ].values
2283+
2284+ modify_data [key ] = DataArray (arr_sorted , coords = coords_out , dims = dims_orig )
2285+
22312286 return self .updated_copy (** modify_data )
22322287
22332288 def sort_modes (
0 commit comments