Skip to content

Commit 6ad9cb1

Browse files
committed
Allow passing single elements to component graph connections()
Like with component graph `components()`, now we support passing a single element to `matching_sources` and `matching_destinations`. The `components()` function is also simplified to make filtering more compact. Also updates uses with a single element to remove the unnecessary container. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 1fe14db commit 6ad9cb1

File tree

2 files changed

+39
-40
lines changed

2 files changed

+39
-40
lines changed

src/frequenz/sdk/microgrid/component_graph.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import asyncio
2525
import logging
2626
from abc import ABC, abstractmethod
27-
from collections.abc import Callable, Iterable, Set
27+
from collections.abc import Callable, Iterable
2828

2929
import networkx as nx
3030
from frequenz.client.common.microgrid.components import ComponentId
@@ -83,8 +83,8 @@ def components(
8383
@abstractmethod
8484
def connections(
8585
self,
86-
matching_sources: set[ComponentId] | None = None,
87-
matching_destinations: set[ComponentId] | None = None,
86+
matching_sources: Iterable[ComponentId] | ComponentId | None = None,
87+
matching_destinations: Iterable[ComponentId] | ComponentId | None = None,
8888
) -> set[ComponentConnection]:
8989
"""Fetch the connections between microgrid components.
9090
@@ -399,27 +399,15 @@ def components(
399399
The set of components currently connected to the microgrid, filtered by
400400
the provided `matching_ids` and `matching_types` values.
401401
"""
402-
match matching_ids:
403-
case ComponentId():
404-
matching_ids = {matching_ids}
405-
case Set():
406-
pass
407-
case Iterable():
408-
matching_ids = set(matching_ids)
409-
410-
match matching_types:
411-
case type():
412-
matching_types = {matching_types}
413-
case Set():
414-
pass
415-
case Iterable():
416-
matching_types = set(matching_types)
402+
matching_ids = _comp_ids_to_iter(matching_ids)
403+
if isinstance(matching_types, type):
404+
matching_types = {matching_types}
417405

418406
selection: Iterable[Component]
419407
selection_ids = (
420408
self._graph.nodes
421409
if matching_ids is None
422-
else matching_ids & self._graph.nodes
410+
else set(matching_ids) & self._graph.nodes
423411
)
424412
selection = (self._graph.nodes[i][_DATA_KEY] for i in selection_ids)
425413

@@ -433,8 +421,8 @@ def components(
433421
@override
434422
def connections(
435423
self,
436-
matching_sources: set[ComponentId] | None = None,
437-
matching_destinations: set[ComponentId] | None = None,
424+
matching_sources: Iterable[ComponentId] | ComponentId | None = None,
425+
matching_destinations: Iterable[ComponentId] | ComponentId | None = None,
438426
) -> set[ComponentConnection]:
439427
"""Fetch the connections between microgrid components.
440428
@@ -447,6 +435,9 @@ def connections(
447435
The set of connections between components in the microgrid, filtered by
448436
the provided `matching_sources` and `matching_destinations` choices.
449437
"""
438+
matching_sources = _comp_ids_to_iter(matching_sources)
439+
matching_destinations = _comp_ids_to_iter(matching_destinations)
440+
450441
match (matching_sources, matching_destinations):
451442
case (None, None):
452443
selection = self._graph.edges
@@ -1128,3 +1119,11 @@ def _validate_leaf_components(self) -> None:
11281119
raise InvalidGraphError(
11291120
f"Leaf components with graph successors: {with_successors}"
11301121
)
1122+
1123+
1124+
def _comp_ids_to_iter(
1125+
ids: Iterable[ComponentId] | ComponentId | None,
1126+
) -> Iterable[ComponentId] | None:
1127+
if isinstance(ids, ComponentId):
1128+
return (ids,)
1129+
return ids

tests/microgrid/test_graph.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -389,18 +389,18 @@ def test_connection_filters(self) -> None:
389389

390390
# with start filter applied, we get back only connections whose `start`
391391
# component matches one of the provided IDs
392-
assert graph.connections(matching_sources={ComponentId(8)}) == set()
393-
assert graph.connections(matching_sources={ComponentId(7)}) == set()
394-
assert graph.connections(matching_sources={ComponentId(6)}) == set()
395-
assert graph.connections(matching_sources={ComponentId(5)}) == set()
396-
assert graph.connections(matching_sources={ComponentId(4)}) == set()
397-
assert graph.connections(matching_sources={ComponentId(3)}) == set()
398-
assert graph.connections(matching_sources={ComponentId(2)}) == {
392+
assert graph.connections(matching_sources=ComponentId(8)) == set()
393+
assert graph.connections(matching_sources=ComponentId(7)) == set()
394+
assert graph.connections(matching_sources=ComponentId(6)) == set()
395+
assert graph.connections(matching_sources=ComponentId(5)) == set()
396+
assert graph.connections(matching_sources=ComponentId(4)) == set()
397+
assert graph.connections(matching_sources=ComponentId(3)) == set()
398+
assert graph.connections(matching_sources=ComponentId(2)) == {
399399
ComponentConnection(source=ComponentId(2), destination=ComponentId(4)),
400400
ComponentConnection(source=ComponentId(2), destination=ComponentId(5)),
401401
ComponentConnection(source=ComponentId(2), destination=ComponentId(6)),
402402
}
403-
assert graph.connections(matching_sources={ComponentId(1)}) == {
403+
assert graph.connections(matching_sources=ComponentId(1)) == {
404404
ComponentConnection(source=ComponentId(1), destination=ComponentId(2)),
405405
ComponentConnection(source=ComponentId(1), destination=ComponentId(3)),
406406
}
@@ -427,23 +427,23 @@ def test_connection_filters(self) -> None:
427427

428428
# with end filter applied, we get back only connections whose `end`
429429
# component matches one of the provided IDs
430-
assert graph.connections(matching_destinations={ComponentId(8)}) == set()
431-
assert graph.connections(matching_destinations={ComponentId(6)}) == {
430+
assert graph.connections(matching_destinations=ComponentId(8)) == set()
431+
assert graph.connections(matching_destinations=ComponentId(6)) == {
432432
ComponentConnection(source=ComponentId(2), destination=ComponentId(6))
433433
}
434-
assert graph.connections(matching_destinations={ComponentId(5)}) == {
434+
assert graph.connections(matching_destinations=ComponentId(5)) == {
435435
ComponentConnection(source=ComponentId(2), destination=ComponentId(5))
436436
}
437-
assert graph.connections(matching_destinations={ComponentId(4)}) == {
437+
assert graph.connections(matching_destinations=ComponentId(4)) == {
438438
ComponentConnection(source=ComponentId(2), destination=ComponentId(4))
439439
}
440-
assert graph.connections(matching_destinations={ComponentId(3)}) == {
440+
assert graph.connections(matching_destinations=ComponentId(3)) == {
441441
ComponentConnection(source=ComponentId(1), destination=ComponentId(3))
442442
}
443-
assert graph.connections(matching_destinations={ComponentId(2)}) == {
443+
assert graph.connections(matching_destinations=ComponentId(2)) == {
444444
ComponentConnection(source=ComponentId(1), destination=ComponentId(2))
445445
}
446-
assert graph.connections(matching_destinations={ComponentId(1)}) == set()
446+
assert graph.connections(matching_destinations=ComponentId(1)) == set()
447447
assert graph.connections(
448448
matching_destinations={ComponentId(1), ComponentId(2), ComponentId(3)}
449449
) == {
@@ -470,18 +470,18 @@ def test_connection_filters(self) -> None:
470470
ComponentConnection(source=ComponentId(2), destination=ComponentId(4)),
471471
ComponentConnection(source=ComponentId(2), destination=ComponentId(6)),
472472
}
473-
assert graph.connections(matching_destinations={ComponentId(1)}) == set()
473+
assert graph.connections(matching_destinations=ComponentId(1)) == set()
474474

475475
# when both filters are applied, they are combined via AND logic, i.e.
476476
# a connection must have its `start` matching one of the provided start
477477
# values, and its `end` matching one of the provided end values
478478
assert graph.connections(
479-
matching_sources={ComponentId(1)}, matching_destinations={ComponentId(2)}
479+
matching_sources=ComponentId(1), matching_destinations=ComponentId(2)
480480
) == {ComponentConnection(source=ComponentId(1), destination=ComponentId(2))}
481481
assert (
482482
graph.connections(
483-
matching_sources={ComponentId(2)},
484-
matching_destinations={ComponentId(3)},
483+
matching_sources=ComponentId(2),
484+
matching_destinations=ComponentId(3),
485485
)
486486
== set()
487487
)

0 commit comments

Comments
 (0)