Skip to content

Commit ef492c1

Browse files
committed
fix: Updates to speed up some mode data operations
1 parent dc01be1 commit ef492c1

File tree

2 files changed

+101
-47
lines changed

2 files changed

+101
-47
lines changed

tidy3d/components/data/monitor_data.py

Lines changed: 100 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -149,20 +149,6 @@ def amplitude_fn(freq: list[float]) -> complex:
149149

150150
return self.normalize(amplitude_fn)
151151

152-
def _updated(self, update: dict) -> MonitorData:
153-
"""Similar to ``updated_copy``, but does not actually copy components, for speed.
154-
155-
Note
156-
----
157-
This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
158-
are not updated, they will point to the values in the original data. This method should
159-
thus be used carefully.
160-
161-
"""
162-
data_dict = self.dict()
163-
data_dict.update(update)
164-
return type(self).parse_obj(data_dict)
165-
166152
def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]:
167153
"""Generate adjoint sources for this ``MonitorData`` instance."""
168154

@@ -261,7 +247,7 @@ def symmetry_expanded(self):
261247
if all(sym == 0 for sym in self.symmetry):
262248
return self
263249

264-
return self._updated(self._symmetry_update_dict)
250+
return self.updated_copy(**self._symmetry_update_dict, deep=False, validate=False)
265251

266252
@property
267253
def symmetry_expanded_copy(self) -> AbstractFieldData:
@@ -780,21 +766,38 @@ def dot(
780766
fields_self = {key: field.conj() for key, field in fields_self.items()}
781767

782768
fields_other = field_data._interpolated_tangential_fields(self._plane_grid_boundaries)
769+
dim1, dim2 = self._tangential_dims
770+
d_area = self._diff_area
783771

784-
# Drop size-1 dimensions in the other data
785-
fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
772+
try:
773+
for field_self, field_other in zip(fields_self.values(), fields_other.values()):
774+
for key, val in field_self.coords:
775+
if not np.all(val.values == field_other.coords[key].values):
776+
raise ValueError("Coordinates do not match.")
777+
778+
# Coordinates match, so we can use .values for speed
779+
e_self_x_h_other = fields_self["E" + dim1].values * fields_other["H" + dim2].values
780+
e_self_x_h_other -= fields_self["E" + dim2].values * fields_other["H" + dim1].values
781+
h_self_x_e_other = fields_self["H" + dim1].values * fields_other["E" + dim2].values
782+
h_self_x_e_other -= fields_self["H" + dim2].values * fields_other["E" + dim1].values
783+
integrand = xr.DataArray(
784+
e_self_x_h_other - h_self_x_e_other, coords=fields_self["E" + dim1].coords
785+
)
786+
integrand *= d_area
787+
except: # noqa: E722
788+
# Catching a broad exception here in case anything went wrong in the check.
786789

787-
# Cross products of fields
788-
dim1, dim2 = self._tangential_dims
789-
e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
790-
e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
791-
h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
792-
h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
790+
# Drop size-1 dimensions in the other data
791+
fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
793792

794-
# Integrate over plane
795-
d_area = self._diff_area
796-
integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
793+
# Cross products of fields
794+
e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
795+
e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
796+
h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
797+
h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
798+
integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
797799

800+
# Integrate over plane
798801
return ModeAmpsDataArray(0.25 * integrand.sum(dim=d_area.dims))
799802

800803
def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]:
@@ -811,6 +814,19 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, Dat
811814
"""
812815
fields = self._tangential_fields
813816

817+
try:
818+
# If coords already match, just return the tangential fields directly
819+
for field in fields.values():
820+
for idim, dim in enumerate(self._tangential_dims):
821+
if field.coords[dim].values.size != coords[idim].size or not np.all(
822+
field.coords[dim].values == coords[idim]
823+
):
824+
raise ValueError("Coordinates do not match.")
825+
return fields
826+
except: # noqa: E722
827+
# Catching a broad exception here in case anything went wrong in the check.
828+
pass
829+
814830
# Interpolate if data has more than one coordinate along a dimension
815831
interp_dict = {"assume_sorted": True}
816832
# If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
@@ -1711,10 +1727,12 @@ def overlap_sort(
17111727

17121728
# Normalizing the flux to 1, does not guarantee self terms of overlap integrals
17131729
# are also normalized to 1 when the non-conjugated product is used.
1714-
if self.monitor.conjugated_dot_product:
1730+
data_expanded = self.symmetry_expanded
1731+
if data_expanded.monitor.conjugated_dot_product:
17151732
self_overlap = np.ones((num_freqs, num_modes))
17161733
else:
1717-
self_overlap = np.abs(self.dot(self, self.monitor.conjugated_dot_product).values)
1734+
self_overlap = data_expanded.dot(data_expanded, self.monitor.conjugated_dot_product)
1735+
self_overlap = np.abs(self_overlap.values)
17181736
threshold_array = overlap_thresh * self_overlap
17191737

17201738
# Compute sorting order and overlaps with neighboring frequencies
@@ -1727,20 +1745,19 @@ def overlap_sort(
17271745
# Sort in two directions from the base frequency
17281746
for step, last_ind in zip([-1, 1], [-1, num_freqs]):
17291747
# Start with the base frequency
1730-
data_template = self._isel(f=[f0_ind])
1748+
data_template = data_expanded._isel(f=[f0_ind])
17311749

17321750
# March to lower/higher frequencies
17331751
for freq_id in range(f0_ind + step, last_ind, step):
17341752
# Calculate threshold array for this frequency
1735-
if not self.monitor.conjugated_dot_product:
1753+
if not data_expanded.monitor.conjugated_dot_product:
17361754
overlap_thresh = threshold_array[freq_id, :]
17371755
# Get next frequency to sort
1738-
data_to_sort = self._isel(f=[freq_id])
1756+
data_to_sort = data_expanded._isel(f=[freq_id])
17391757
# Assign to the base frequency so that outer_dot will compare them
17401758
data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]])
17411759

17421760
# Compute "sorting w.r.t. to neighbor" and overlap values
1743-
17441761
sorting_one_mode, amps_one_mode = data_template._find_ordering_one_freq(
17451762
data_to_sort, overlap_thresh
17461763
)
@@ -1756,8 +1773,8 @@ def overlap_sort(
17561773
for mode_ind in list(np.nonzero(overlap[freq_id, :] < overlap_thresh)[0]):
17571774
log.warning(
17581775
f"Mode '{mode_ind}' appears to undergo a discontinuous change "
1759-
f"between frequencies '{self.monitor.freqs[freq_id]}' "
1760-
f"and '{self.monitor.freqs[freq_id - step]}' "
1776+
f"between frequencies '{data_expanded.monitor.freqs[freq_id]}' "
1777+
f"and '{data_expanded.monitor.freqs[freq_id - step]}' "
17611778
f"(overlap: '{overlap[freq_id, mode_ind]:.2f}')."
17621779
)
17631780

@@ -1796,7 +1813,7 @@ def _isel(self, **isel_kwargs):
17961813
for key, field in update_dict.items()
17971814
if isinstance(field, DataArray)
17981815
}
1799-
return self._updated(update=update_dict)
1816+
return self.updated_copy(**update_dict, deep=False, validate=False)
18001817

18011818
def _assign_coords(self, **assign_coords_kwargs):
18021819
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
@@ -1808,7 +1825,7 @@ def _assign_coords(self, **assign_coords_kwargs):
18081825
update_dict = {
18091826
key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
18101827
}
1811-
return self._updated(update=update_dict)
1828+
return self.updated_copy(**update_dict, deep=False, validate=False)
18121829

18131830
def _find_ordering_one_freq(
18141831
self,
@@ -2214,20 +2231,58 @@ def _apply_mode_reorder(self, sort_inds_2d):
22142231
Array of shape (num_freqs, num_modes) where each row is the
22152232
permutation to apply to the mode_index for that frequency.
22162233
"""
2234+
sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
22172235
num_freqs, num_modes = sort_inds_2d.shape
2236+
2237+
# Fast no-op
2238+
identity = np.arange(num_modes)
2239+
if np.all(sort_inds_2d == identity[None, :]):
2240+
return self
2241+
22182242
modify_data = {}
2243+
new_mode_index_coord = identity
2244+
22192245
for key, data in self.data_arrs.items():
22202246
if "mode_index" not in data.dims or "f" not in data.dims:
22212247
continue
2222-
dims_orig = data.dims
2223-
f_coord = data.coords["f"]
2224-
slices = []
2225-
for ifreq in range(num_freqs):
2226-
sl = data.isel(f=ifreq, mode_index=sort_inds_2d[ifreq])
2227-
slices.append(sl.assign_coords(mode_index=np.arange(num_modes)))
2228-
# Concatenate along the 'f' dimension name and then restore original frequency coordinates
2229-
data = xr.concat(slices, dim="f").assign_coords(f=f_coord).transpose(*dims_orig)
2230-
modify_data[key] = data
2248+
2249+
dims_orig = tuple(data.dims)
2250+
# Preserve coords (as numpy)
2251+
coords_out = {
2252+
k: (v.values if hasattr(v, "values") else np.asarray(v))
2253+
for k, v in data.coords.items()
2254+
}
2255+
f_axis = data.get_axis_num("f")
2256+
m_axis = data.get_axis_num("mode_index")
2257+
2258+
# Move axes so array is (..., f, mode)
2259+
move_order = [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [
2260+
f_axis,
2261+
m_axis,
2262+
]
2263+
arr = np.moveaxis(data.data, move_order, range(data.ndim))
2264+
lead_shape = arr.shape[:-2]
2265+
nf, nm = arr.shape[-2], arr.shape[-1]
2266+
if nf != num_freqs or nm != num_modes:
2267+
raise DataError(
2268+
"sort_inds_2d shape does not match array shape in _apply_mode_reorder."
2269+
)
2270+
2271+
# Vectorized gather: reshape to (nf, Nlead, nm), gather along last axis
2272+
arr3 = arr.reshape((-1, nf, nm)).transpose(1, 0, 2) # (nf, Nlead, nm)
2273+
inds = sort_inds_2d[:, None, :] # (nf, 1, nm)
2274+
arr3_sorted = np.take_along_axis(arr3, inds, axis=2)
2275+
arr_sorted = arr3_sorted.transpose(1, 0, 2).reshape(*lead_shape, nf, nm)
2276+
2277+
# Move axes back to original order
2278+
arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), move_order)
2279+
2280+
# Update coords: keep f, reset mode_index to 0..num_modes-1
2281+
coords_out["mode_index"] = new_mode_index_coord
2282+
coords_out["f"] = data.coords["f"].values
2283+
2284+
modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)
2285+
22312286
return self.updated_copy(**modify_data)
22322287

22332288
def sort_modes(

tidy3d/components/mode/mode_solver.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,8 +1334,7 @@ def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
13341334
mode_solver_monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME)
13351335
grid_expanded = self.simulation.discretize_monitor(mode_solver_monitor)
13361336
data_dict_colocated.update({"monitor": mode_solver_monitor, "grid_expanded": grid_expanded})
1337-
mode_solver_data = mode_solver_data._updated(update=data_dict_colocated)
1338-
1337+
mode_solver_data = mode_solver_data.updated_copy(**data_dict_colocated, deep=False)
13391338
return mode_solver_data
13401339

13411340
def _normalize_modes(self, mode_solver_data: ModeSolverData):

0 commit comments

Comments
 (0)