Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions cheetah/accelerator/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,14 +587,14 @@ def get_beam_attrs_along_segment(
def get_attr_from_every_element(
self,
filter_type: type[Element] | tuple[type[Element]] | None = None,
filter_name: str | None = None,
filter_name: str | tuple[str] | None = None,
is_recursive: bool = True,
) -> list[Any]:
"""
Get an attribute from every element type in the segment filtered by type and/or
name.
:param filter_type: Type of the elements to get the attribute from.
:param filter_name: Name of the elements to get the attribute from.
:param filter_name: Name of a single element or a tuple of names to filter by.
:param is_recursive: If `True`, this method is applied to nested `Segment`s as
well. If `False`, only the elements directly in the top-level `Segment` are
considered.
Expand All @@ -603,7 +603,9 @@ def get_attr_from_every_element(
attrs = []
for element in self.elements:
if (filter_type is None or isinstance(element, filter_type)) and (
filter_name is None or element.name in filter_name
filter_name is None
or (isinstance(filter_name, str) and element.name == filter_name)
or (isinstance(filter_name, tuple) and element.name in filter_name)
):
attrs.append(element)
elif is_recursive and isinstance(element, Segment):
Expand All @@ -628,15 +630,17 @@ def set_attrs_on_every_element(
name.

:param filter_type: Type of the elements to set the attributes for.
:param filter_name: Names of the elements to set the attributes for.
:param filter_name: Name of a single element or a tuple of names to filter by.
:param is_recursive: If `True`, this method is applied to nested `Segment`s as
well. If `False`, only the elements directly in the top-level `Segment` are
considered.
:param kwargs: Attributes to set and their values.
"""
for element in self.elements:
if (filter_type is None or isinstance(element, filter_type)) and (
filter_name is None or element.name in filter_name
filter_name is None
or (isinstance(filter_name, str) and element.name == filter_name)
or (isinstance(filter_name, tuple) and element.name in filter_name)
):
for key, value in kwargs.items():
setattr(element, key, value)
Expand Down