Skip to content

Conversation

@momchil-flex
Copy link
Collaborator

@momchil-flex momchil-flex commented Oct 24, 2025

As discussed here I made some changes to speed up the overlap sort.

  • In that comment, I mention that my recent sorting updates did not degrade performance, but I was actually wrong (I was testing without those...) The new reorder I introduces no longer used in-place operations on the data arrays directly, which is more robust, but it was also slower. So I replaced this with a function that still acts on the data directly, but unfortunately to keep the robustness this introduces some complexity in rearranging dimensions.
  • Apply symmetry expansion before starting the frequency tracking to avoid doing it repeatedly.
  • Add fast pathways in some methods when coordinates already match (as would be the case for doing dot products of mode data with itself at a different frequency or mode index).
  • Remove monitor_data._updated as we now support updated_copy(deep=False), which is both better documented and seems to work faster.

Greptile Overview

Updated On: 2025-10-24 13:20:46 UTC

Greptile Summary

Optimized mode data operations by replacing the internal _updated() method with direct calls to updated_copy(deep=False, validate=False) to avoid unnecessary deep copying and validation overhead.

Key changes:

  • Removed _updated() method from MonitorData class which used dict() + parse_obj() approach
  • Replaced with updated_copy() using deep=False and validate=False flags in 5 locations across monitor_data.py and 1 location in mode_solver.py
  • Added early-exit optimization in _interpolated_tangential_fields() that checks if coordinates already match using np.allclose() before interpolation
  • Improved overlap_sort() efficiency by using symmetry_expanded data once and removing unnecessary coordinate reassignment via _assign_coords()

The changes maintain the same shallow-copy semantics while using the more standard updated_copy() API with appropriate flags, resulting in cleaner code and improved performance.

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations around the early-exit optimization
  • The changes are well-focused performance optimizations that replace a custom _updated() method with the standard updated_copy() API. The refactoring maintains the same shallow-copy semantics and improves code consistency. However, the new early-exit optimization in _interpolated_tangential_fields() uses np.allclose() for coordinate comparison, which is appropriate for floating-point comparisons but should be monitored to ensure it doesn't skip interpolation when coordinates are very close but not identical enough for the use case.
  • Pay attention to monitor_data.py:800-808 where the new coordinate matching logic was added

Important Files Changed

File Analysis

Filename Score Overview
tidy3d/components/data/monitor_data.py 4/5 Replaced _updated() method with updated_copy(deep=False, validate=False) for performance; added early-exit optimization in _interpolated_tangential_fields(); optimized overlap_sort() to use symmetry_expanded and removed unnecessary coordinate reassignment
tidy3d/components/mode/mode_solver.py 5/5 Replaced _updated() method call with updated_copy(deep=False, validate=False) for consistency with monitor_data.py changes

Sequence Diagram

sequenceDiagram
    participant Client
    participant ModeData
    participant AbstractFieldData
    participant ModeSolver
    
    Note over ModeData: overlap_sort() optimization
    Client->>ModeData: overlap_sort(track_freq, overlap_thresh)
    ModeData->>AbstractFieldData: symmetry_expanded
    AbstractFieldData-->>ModeData: data_expanded (self or new copy)
    
    alt No symmetry (returns self)
        Note over AbstractFieldData: Returns self directly
    else Has symmetry
        AbstractFieldData->>AbstractFieldData: updated_copy(deep=False, validate=False)
        Note over AbstractFieldData: Creates shallow copy with expanded fields
    end
    
    ModeData->>ModeData: dot(data_expanded, conjugate)
    Note over ModeData: Compute self-overlap
    
    loop For each frequency direction
        ModeData->>ModeData: _isel(f=[freq_id])
        Note over ModeData: Extract frequency slice
        ModeData->>ModeData: _find_ordering_one_freq()
        Note over ModeData: Uses dot() instead of outer_dot()
    end
    
    ModeData->>ModeData: updated_copy(deep=False, validate=False)
    ModeData-->>Client: Sorted mode data
    
    Note over ModeSolver: _colocate_data() optimization
    Client->>ModeSolver: _colocate_data(mode_solver_data)
    ModeSolver->>ModeSolver: updated_copy(deep=False, validate=False)
    Note over ModeSolver: Replaces _updated() method
    ModeSolver-->>Client: Colocated data
    
    Note over AbstractFieldData: _interpolated_tangential_fields() optimization
    Client->>AbstractFieldData: _interpolated_tangential_fields(coords)
    AbstractFieldData->>AbstractFieldData: Check if coords match using np.allclose()
    
    alt Coords already match
        AbstractFieldData-->>Client: Return fields directly (early exit)
    else Coords don't match
        AbstractFieldData->>AbstractFieldData: Interpolate fields
        AbstractFieldData-->>Client: Return interpolated fields
    end
Loading

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@github-actions
Copy link
Contributor

github-actions bot commented Oct 27, 2025

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/data/monitor_data.py (98.2%): Missing lines 2269
  • tidy3d/components/mode/mode_solver.py (100%)

Summary

  • Total: 56 lines
  • Missing: 1 line
  • Coverage: 98%

tidy3d/components/data/monitor_data.py

Lines 2265-2273

  2265             )
  2266             arr = np.moveaxis(data.data, src_order, range(data.ndim))
  2267             nf, nm = arr.shape[0], arr.shape[-1]
  2268             if nf != num_freqs or nm != num_modes:
! 2269                 raise DataError(
  2270                     "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
  2271                 )
  2272 
  2273             # Apply sorting

@momchil-flex momchil-flex force-pushed the momchil/speed_up_overlap_sort branch 2 times, most recently from c9566b8 to ef492c1 Compare October 27, 2025 12:54
@momchil-flex
Copy link
Collaborator Author

@greptile give this another go buddy

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the code and the changes, I can see you're trying to optimize mode data operations. Let me provide a concrete suggestion for the section that needs fixing.

The issue is in the _apply_mode_reorder method where you're trying to use np.take_along_axis with broadcasting, but the shapes don't align correctly. Here's the fix:

        sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
        num_freqs, num_modes = sort_inds_2d.shape

        # Fast no-op
        identity = np.arange(num_modes)
        if np.all(sort_inds_2d == identity[None, :]):
            return self

        modify_data = {}
        new_mode_index_coord = identity

        for key, data in self.data_arrs.items():
            if "mode_index" not in data.dims or "f" not in data.dims:
                continue

            dims_orig = tuple(data.dims)
            # Preserve coords (as numpy)
            coords_out = {
                k: (v.values if hasattr(v, "values") else np.asarray(v))
                for k, v in data.coords.items()
            }
            f_axis = data.get_axis_num("f")
            m_axis = data.get_axis_num("mode_index")

            # Move axes so array is (..., f, mode)
            move_order = [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [
                f_axis,
                m_axis,
            ]
            arr = np.moveaxis(data.values, move_order, range(data.ndim))
            lead_shape = arr.shape[:-2]
            nf, nm = arr.shape[-2], arr.shape[-1]
            if nf != num_freqs or nm != num_modes:
                raise DataError(
                    "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
                )

            # Vectorized gather: reshape to (Nlead, nf, nm), gather along last axis
            arr_reshaped = arr.reshape((-1, nf, nm))
            # Expand sort_inds_2d to (Nlead, nf, nm)
            inds = np.broadcast_to(sort_inds_2d[None, :, :], arr_reshaped.shape)
            arr_sorted_flat = np.take_along_axis(arr_reshaped, inds, axis=2)
            arr_sorted = arr_sorted_flat.reshape(*lead_shape, nf, nm)

            # Move axes back to original order
            arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), move_order)

            # Update coords: keep f, reset mode_index to 0..num_modes-1
            coords_out["mode_index"] = new_mode_index_coord
            coords_out["f"] = data.coords["f"].values

            modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)

        return self.updated_copy(**modify_data)

The key changes:

  1. Changed data.data to data.values for consistency
  2. Fixed the reshaping to (Nlead, nf, nm) instead of transposing
  3. Used np.broadcast_to to expand sort_inds_2d to match the full array shape
  4. Applied np.take_along_axis with proper broadcasting

Copy link
Contributor

@dbochkov-flexcompute dbochkov-flexcompute left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just a couple minor (?) comments

@momchil-flex momchil-flex force-pushed the momchil/speed_up_overlap_sort branch from b972ffa to 884f8b4 Compare October 28, 2025 10:10
Copy link
Contributor

@dmarek-flex dmarek-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, but I do think trying to avoid those try/except usages would be better.

@momchil-flex momchil-flex force-pushed the momchil/speed_up_overlap_sort branch from 13e6e94 to c5e9bf8 Compare October 31, 2025 13:04
@momchil-flex
Copy link
Collaborator Author

Alright, I removed the try: except logic and made a standalone method _tangential_fields_match_coords to test if a given field data's tangential coords match a provided set. Now in dot I just test if the shapes are the same, because the tangential coordinates have to be the same after the call to the interpolation. Previously, I was attempting to fast-path the computation even when the shapes are not the same (broadcasting) but that was actually not always working and it was masked by the try - except. Doing the broadcasting with numpy is pretty nontrivial because we might have to do it on freqs and/or mode indexes, so I just fast-track if shapes are identical, and let xarray handle the more complicated cases for robustness.

Thanks for the useful suggestions!

@momchil-flex momchil-flex added this pull request to the merge queue Oct 31, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2025
@momchil-flex momchil-flex added this pull request to the merge queue Oct 31, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants