- 
                Notifications
    You must be signed in to change notification settings 
- Fork 65
FXC-3351-Updates to speed up some mode data operations #2927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 1 comment
| Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
 Summary
 tidy3d/components/data/monitor_data.pyLines 2265-2273   2265             )
  2266             arr = np.moveaxis(data.data, src_order, range(data.ndim))
  2267             nf, nm = arr.shape[0], arr.shape[-1]
  2268             if nf != num_freqs or nm != num_modes:
! 2269                 raise DataError(
  2270                     "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
  2271                 )
  2272 
  2273             # Apply sorting | 
c9566b8    to
    ef492c1      
    Compare
  
    | @greptile give this another go buddy | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the code and the changes, I can see you're trying to optimize mode data operations. Let me provide a concrete suggestion for the section that needs fixing.
The issue is in the _apply_mode_reorder method where you're trying to use np.take_along_axis with broadcasting, but the shapes don't align correctly. Here's the fix:
        sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
        num_freqs, num_modes = sort_inds_2d.shape
        # Fast no-op
        identity = np.arange(num_modes)
        if np.all(sort_inds_2d == identity[None, :]):
            return self
        modify_data = {}
        new_mode_index_coord = identity
        for key, data in self.data_arrs.items():
            if "mode_index" not in data.dims or "f" not in data.dims:
                continue
            dims_orig = tuple(data.dims)
            # Preserve coords (as numpy)
            coords_out = {
                k: (v.values if hasattr(v, "values") else np.asarray(v))
                for k, v in data.coords.items()
            }
            f_axis = data.get_axis_num("f")
            m_axis = data.get_axis_num("mode_index")
            # Move axes so array is (..., f, mode)
            move_order = [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [
                f_axis,
                m_axis,
            ]
            arr = np.moveaxis(data.values, move_order, range(data.ndim))
            lead_shape = arr.shape[:-2]
            nf, nm = arr.shape[-2], arr.shape[-1]
            if nf != num_freqs or nm != num_modes:
                raise DataError(
                    "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
                )
            # Vectorized gather: reshape to (Nlead, nf, nm), gather along last axis
            arr_reshaped = arr.reshape((-1, nf, nm))
            # Expand sort_inds_2d to (Nlead, nf, nm)
            inds = np.broadcast_to(sort_inds_2d[None, :, :], arr_reshaped.shape)
            arr_sorted_flat = np.take_along_axis(arr_reshaped, inds, axis=2)
            arr_sorted = arr_sorted_flat.reshape(*lead_shape, nf, nm)
            # Move axes back to original order
            arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), move_order)
            # Update coords: keep f, reset mode_index to 0..num_modes-1
            coords_out["mode_index"] = new_mode_index_coord
            coords_out["f"] = data.coords["f"].values
            modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)
        return self.updated_copy(**modify_data)
The key changes:
- Changed data.datatodata.valuesfor consistency
- Fixed the reshaping to (Nlead, nf, nm)instead of transposing
- Used np.broadcast_toto expandsort_inds_2dto match the full array shape
- Applied np.take_along_axiswith proper broadcasting
ef492c1    to
    b972ffa      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just a couple minor (?) comments
b972ffa    to
    884f8b4      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, but I do think trying to avoid those try/except usages would be better.
13e6e94    to
    c5e9bf8      
    Compare
  
    | Alright, I removed the try: except logic and made a standalone method  Thanks for the useful suggestions! | 
As discussed here I made some changes to speed up the overlap sort.
monitor_data._updatedas we now supportupdated_copy(deep=False), which is both better documented and seems to work faster.Greptile Overview
Updated On: 2025-10-24 13:20:46 UTC
Greptile Summary
Optimized mode data operations by replacing the internal
_updated()method with direct calls toupdated_copy(deep=False, validate=False)to avoid unnecessary deep copying and validation overhead.Key changes:
_updated()method fromMonitorDataclass which useddict()+parse_obj()approachupdated_copy()usingdeep=Falseandvalidate=Falseflags in 5 locations acrossmonitor_data.pyand 1 location inmode_solver.py_interpolated_tangential_fields()that checks if coordinates already match usingnp.allclose()before interpolationoverlap_sort()efficiency by usingsymmetry_expandeddata once and removing unnecessary coordinate reassignment via_assign_coords()The changes maintain the same shallow-copy semantics while using the more standard
updated_copy()API with appropriate flags, resulting in cleaner code and improved performance.Confidence Score: 4/5
_updated()method with the standardupdated_copy()API. The refactoring maintains the same shallow-copy semantics and improves code consistency. However, the new early-exit optimization in_interpolated_tangential_fields()usesnp.allclose()for coordinate comparison, which is appropriate for floating-point comparisons but should be monitored to ensure it doesn't skip interpolation when coordinates are very close but not identical enough for the use case.monitor_data.py:800-808where the new coordinate matching logic was addedImportant Files Changed
File Analysis
_updated()method withupdated_copy(deep=False, validate=False)for performance; added early-exit optimization in_interpolated_tangential_fields(); optimizedoverlap_sort()to usesymmetry_expandedand removed unnecessary coordinate reassignment_updated()method call withupdated_copy(deep=False, validate=False)for consistency with monitor_data.py changesSequence Diagram
sequenceDiagram participant Client participant ModeData participant AbstractFieldData participant ModeSolver Note over ModeData: overlap_sort() optimization Client->>ModeData: overlap_sort(track_freq, overlap_thresh) ModeData->>AbstractFieldData: symmetry_expanded AbstractFieldData-->>ModeData: data_expanded (self or new copy) alt No symmetry (returns self) Note over AbstractFieldData: Returns self directly else Has symmetry AbstractFieldData->>AbstractFieldData: updated_copy(deep=False, validate=False) Note over AbstractFieldData: Creates shallow copy with expanded fields end ModeData->>ModeData: dot(data_expanded, conjugate) Note over ModeData: Compute self-overlap loop For each frequency direction ModeData->>ModeData: _isel(f=[freq_id]) Note over ModeData: Extract frequency slice ModeData->>ModeData: _find_ordering_one_freq() Note over ModeData: Uses dot() instead of outer_dot() end ModeData->>ModeData: updated_copy(deep=False, validate=False) ModeData-->>Client: Sorted mode data Note over ModeSolver: _colocate_data() optimization Client->>ModeSolver: _colocate_data(mode_solver_data) ModeSolver->>ModeSolver: updated_copy(deep=False, validate=False) Note over ModeSolver: Replaces _updated() method ModeSolver-->>Client: Colocated data Note over AbstractFieldData: _interpolated_tangential_fields() optimization Client->>AbstractFieldData: _interpolated_tangential_fields(coords) AbstractFieldData->>AbstractFieldData: Check if coords match using np.allclose() alt Coords already match AbstractFieldData-->>Client: Return fields directly (early exit) else Coords don't match AbstractFieldData->>AbstractFieldData: Interpolate fields AbstractFieldData-->>Client: Return interpolated fields end