Skip to content

Commit 427758e

Browse files
authored
Filter excluded states in entity trigger base class (home-assistant#169956)
1 parent c2ce313 commit 427758e

10 files changed

Lines changed: 57 additions & 113 deletions

File tree

homeassistant/components/counter/trigger.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
"""Provides triggers for counters."""
22

3-
from homeassistant.const import (
4-
CONF_MAXIMUM,
5-
CONF_MINIMUM,
6-
STATE_UNAVAILABLE,
7-
STATE_UNKNOWN,
8-
)
3+
from homeassistant.const import CONF_MAXIMUM, CONF_MINIMUM
94
from homeassistant.core import HomeAssistant, State
105
from homeassistant.helpers.automation import DomainSpec
116
from homeassistant.helpers.trigger import (
@@ -41,19 +36,15 @@ class CounterDecrementedTrigger(CounterBaseIntegerTrigger):
4136
"""Trigger for when a counter is decremented."""
4237

4338
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
44-
"""Check if the origin state is valid and the state has changed."""
45-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
46-
return False
39+
"""Check that the counter value decreased."""
4740
return int(from_state.state) > int(to_state.state)
4841

4942

5043
class CounterIncrementedTrigger(CounterBaseIntegerTrigger):
5144
"""Trigger for when a counter is incremented."""
5245

5346
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
54-
"""Check if the origin state is valid and the state has changed."""
55-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
56-
return False
47+
"""Check that the counter value increased."""
5748
return int(from_state.state) < int(to_state.state)
5849

5950

@@ -62,12 +53,6 @@ class CounterValueBaseTrigger(EntityTriggerBase):
6253

6354
_domain_specs = {DOMAIN: DomainSpec()}
6455

65-
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
66-
"""Check if the origin state is valid and the state has changed."""
67-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
68-
return False
69-
return from_state.state != to_state.state
70-
7156

7257
class CounterMaxReachedTrigger(CounterValueBaseTrigger):
7358
"""Trigger for when a counter reaches its maximum value."""

homeassistant/components/cover/trigger.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Mapping
44

5-
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
5+
from homeassistant.const import STATE_OFF, STATE_ON
66
from homeassistant.core import HomeAssistant, State
77
from homeassistant.helpers.trigger import EntityTriggerBase, Trigger
88

@@ -28,9 +28,7 @@ def is_valid_state(self, state: State) -> bool:
2828
return self._get_value(state) == domain_spec.target_value
2929

3030
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
31-
"""Check if the transition is valid for a cover state change."""
32-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
33-
return False
31+
"""Check that the relevant cover value changed."""
3432
if (from_value := self._get_value(from_state)) is None:
3533
return False
3634
return from_value != self._get_value(to_state)

homeassistant/components/doorbell/trigger.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ class DoorbellRangTrigger(StatelessEntityTriggerBase):
1717
_domain_specs = {EVENT_DOMAIN: DomainSpec(device_class=EventDeviceClass.DOORBELL)}
1818

1919
def is_valid_state(self, state: State) -> bool:
20-
"""Check if the entity is available and the event type is ring."""
21-
return super().is_valid_state(state) and (
22-
state.attributes.get(ATTR_EVENT_TYPE) == DoorbellEventType.RING
23-
)
20+
"""Check if the event type is ring."""
21+
return state.attributes.get(ATTR_EVENT_TYPE) == DoorbellEventType.RING
2422

2523

2624
TRIGGERS: dict[str, type[Trigger]] = {

homeassistant/components/event/trigger.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
4141

4242
def is_valid_state(self, state: State) -> bool:
4343
"""Check if the event type matches one of the configured types."""
44-
return super().is_valid_state(state) and (
45-
state.attributes.get(ATTR_EVENT_TYPE) in self._event_types
46-
)
44+
return state.attributes.get(ATTR_EVENT_TYPE) in self._event_types
4745

4846

4947
TRIGGERS: dict[str, type[Trigger]] = {

homeassistant/components/media_player/trigger.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Provides triggers for media players."""
22

3-
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
43
from homeassistant.core import HomeAssistant, State
54
from homeassistant.helpers.automation import DomainSpec
65
from homeassistant.helpers.trigger import (
@@ -50,10 +49,7 @@ def is_muted(self, state: State) -> bool:
5049
)
5150

5251
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
53-
"""Check if the origin state is valid and the state has changed."""
54-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
55-
return False
56-
52+
"""Check that the muted-state changed."""
5753
if not self._has_volume_attributes(to_state):
5854
return False
5955

homeassistant/components/schedule/trigger.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Provides triggers for schedules."""
22

3-
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
3+
from homeassistant.const import STATE_OFF, STATE_ON
44
from homeassistant.core import HomeAssistant, State
55
from homeassistant.helpers.automation import DomainSpec
66
from homeassistant.helpers.trigger import (
@@ -20,10 +20,7 @@ class ScheduleBackToBackTrigger(EntityTransitionTriggerBase):
2020
_to_states = {STATE_ON}
2121

2222
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
23-
"""Check if the origin state matches the expected ones."""
24-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
25-
return False
26-
23+
"""Check that the origin matches and the next event changed."""
2724
from_next_event = from_state.attributes.get(ATTR_NEXT_EVENT)
2825
to_next_event = to_state.attributes.get(ATTR_NEXT_EVENT)
2926

homeassistant/components/select/trigger.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Provides triggers for selects."""
22

33
from homeassistant.components.input_select import DOMAIN as INPUT_SELECT_DOMAIN
4-
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
5-
from homeassistant.core import HomeAssistant, State
4+
from homeassistant.core import HomeAssistant
65
from homeassistant.helpers.automation import DomainSpec
76
from homeassistant.helpers.trigger import (
87
ENTITY_STATE_TRIGGER_SCHEMA,
@@ -19,16 +18,6 @@ class SelectionChangedTrigger(EntityTriggerBase):
1918
_domain_specs = {DOMAIN: DomainSpec(), INPUT_SELECT_DOMAIN: DomainSpec()}
2019
_schema = ENTITY_STATE_TRIGGER_SCHEMA
2120

22-
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
23-
"""Check if the origin state is valid and the state has changed."""
24-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
25-
return False
26-
return from_state.state != to_state.state
27-
28-
def is_valid_state(self, state: State) -> bool:
29-
"""Check if the new state is not invalid."""
30-
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
31-
3221

3322
TRIGGERS: dict[str, type[Trigger]] = {
3423
"selection_changed": SelectionChangedTrigger,

homeassistant/components/text/trigger.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Provides triggers for text and input_text entities."""
22

33
from homeassistant.components.input_text import DOMAIN as INPUT_TEXT_DOMAIN
4-
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
5-
from homeassistant.core import HomeAssistant, State
4+
from homeassistant.core import HomeAssistant
65
from homeassistant.helpers.automation import DomainSpec
76
from homeassistant.helpers.trigger import (
87
ENTITY_STATE_TRIGGER_SCHEMA,
@@ -19,16 +18,6 @@ class TextChangedTrigger(EntityTriggerBase):
1918
_domain_specs = {DOMAIN: DomainSpec(), INPUT_TEXT_DOMAIN: DomainSpec()}
2019
_schema = ENTITY_STATE_TRIGGER_SCHEMA
2120

22-
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
23-
"""Check if the origin state is valid and the state has changed."""
24-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
25-
return False
26-
return from_state.state != to_state.state
27-
28-
def is_valid_state(self, state: State) -> bool:
29-
"""Check if the new state is not invalid."""
30-
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
31-
3221

3322
TRIGGERS: dict[str, type[Trigger]] = {
3423
"changed": TextChangedTrigger,

homeassistant/helpers/trigger.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,14 @@ class EntityTriggerBase(Trigger):
353353
"""Trigger for entity state changes."""
354354

355355
_domain_specs: Mapping[str, DomainSpec]
356+
# States filtered from the to_state pre-filter (and `_should_include`).
356357
_excluded_states: Final[frozenset[str]] = frozenset(
357358
{STATE_UNAVAILABLE, STATE_UNKNOWN}
358359
)
360+
# States filtered from the from_state pre-filter. Defaults to
361+
# `_excluded_states`. Subclasses can override to relax the origin
362+
# check.
363+
_excluded_from_states: ClassVar[frozenset[str]] = _excluded_states
359364
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST
360365
# When True, indirect target expansion (via device/area/floor) skips
361366
# entities with an entity_category.
@@ -389,13 +394,28 @@ def _get_tracked_value(self, state: State) -> Any:
389394
return state.state
390395
return state.attributes.get(domain_spec.value_source)
391396

392-
@abc.abstractmethod
393397
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
394-
"""Check if the origin state is valid and the state has changed."""
398+
"""Check if the transition should fire the trigger.
399+
400+
Called only after `from_state.state` has been filtered against
401+
`_excluded_from_states` and `to_state.state` against
402+
`_excluded_states`, so subclasses don't need to repeat those
403+
checks. Default: any state change. Override to add semantics
404+
(specific from/to states, value changed across a threshold,
405+
etc.).
406+
"""
407+
return from_state.state != to_state.state
395408

396-
@abc.abstractmethod
397409
def is_valid_state(self, state: State) -> bool:
398-
"""Check if the new state matches the expected state(s)."""
410+
"""Check if the state is a target state for the trigger.
411+
412+
Called only after `state.state` has been filtered against
413+
`_excluded_states`, so subclasses don't need to repeat that
414+
check. Default: any non-excluded state is a target. Override
415+
to restrict (specific to_states, value within a threshold,
416+
etc.).
417+
"""
418+
return True
399419

400420
def _should_include(self, state: State) -> bool:
401421
"""Check if an entity should participate in all/count checks.
@@ -473,19 +493,26 @@ def state_still_valid(
473493
)
474494
return matches >= 1
475495
# Behavior any: check the individual entity's state
476-
if not to_state:
496+
if not to_state or to_state.state in self._excluded_states:
477497
return False
478498
return self.is_valid_state(to_state)
479499

480500
if not from_state or not to_state:
481501
return
482502

483-
# The trigger should never fire if the new state is not valid
484-
if not self.is_valid_state(to_state):
503+
# The trigger should never fire if the new state is excluded
504+
# or not a target state.
505+
if to_state.state in self._excluded_states or not self.is_valid_state(
506+
to_state
507+
):
485508
return
486509

487-
# The trigger should never fire if the transition is not valid
488-
if not self.is_valid_transition(from_state, to_state):
510+
# The trigger should never fire if the origin state is excluded
511+
# or the transition is not valid.
512+
if (
513+
from_state.state in self._excluded_from_states
514+
or not self.is_valid_transition(from_state, to_state)
515+
):
489516
return
490517

491518
if behavior == BEHAVIOR_LAST:
@@ -570,10 +597,7 @@ class EntityTargetStateTriggerBase(EntityTriggerBase):
570597
_to_states: set[str]
571598

572599
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
573-
"""Check if the origin state is valid and the state has changed."""
574-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
575-
return False
576-
600+
"""Check the value changed and the origin was not already a target state."""
577601
from_value = self._get_tracked_value(from_state)
578602
return (
579603
from_value != self._get_tracked_value(to_state)
@@ -593,9 +617,6 @@ class EntityTransitionTriggerBase(EntityTriggerBase):
593617

594618
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
595619
"""Check if the origin state matches the expected ones."""
596-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
597-
return False
598-
599620
from_value = self._get_tracked_value(from_state)
600621
return (
601622
from_value != self._get_tracked_value(to_state)
@@ -620,34 +641,21 @@ def is_valid_transition(self, from_state: State, to_state: State) -> bool:
620641
)
621642

622643
def is_valid_state(self, state: State) -> bool:
623-
"""Check if the new state is valid."""
624-
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) and bool(
625-
self._get_tracked_value(state) != self._from_state
626-
)
644+
"""Check that the new state is different from the origin state."""
645+
return bool(self._get_tracked_value(state) != self._from_state)
627646

628647

629648
class StatelessEntityTriggerBase(EntityTriggerBase):
630649
"""Trigger for entities that don't carry meaningful state.
631650
632651
Used for stateless entities (buttons, scenes, doorbells, events)
633652
whose `state.state` is just a timestamp of the last activation.
653+
`STATE_UNKNOWN` is a legitimate prior state — the first activation
654+
after startup must still fire the trigger.
634655
"""
635656

636657
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA
637-
638-
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
639-
"""Check if the origin state is available and the state has changed.
640-
641-
STATE_UNKNOWN is allowed as the origin state so the first
642-
activation fires.
643-
"""
644-
if from_state.state == STATE_UNAVAILABLE:
645-
return False
646-
return from_state.state != to_state.state
647-
648-
def is_valid_state(self, state: State) -> bool:
649-
"""Check that the entity has been activated at least once."""
650-
return state.state not in self._excluded_states
658+
_excluded_from_states: ClassVar[frozenset[str]] = frozenset({STATE_UNAVAILABLE})
651659

652660

653661
NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA.extend(
@@ -826,10 +834,7 @@ class EntityNumericalStateChangedTriggerBase(EntityNumericalStateTriggerBase):
826834
_schema = NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA
827835

828836
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
829-
"""Check if the origin state is valid and the state has changed."""
830-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
831-
return False
832-
837+
"""Check if the tracked numeric value has changed."""
833838
return self._get_tracked_value(from_state) != self._get_tracked_value(to_state)
834839

835840

@@ -888,10 +893,7 @@ class EntityNumericalStateCrossedThresholdTriggerBase(EntityNumericalStateTrigge
888893
_schema = NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA
889894

890895
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
891-
"""Check if the origin state is valid and the state has changed."""
892-
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
893-
return False
894-
896+
"""Check that the tracked value crossed into the threshold range."""
895897
return not self.is_valid_state(from_state)
896898

897899

tests/helpers/test_trigger.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,10 +2969,6 @@ async def test_make_entity_target_state_trigger(
29692969
# Value did not change — not a valid transition
29702970
assert not trig.is_valid_transition(from_state, from_state)
29712971

2972-
# From unavailable — not valid
2973-
unavailable = State("light.bed", STATE_UNAVAILABLE, {})
2974-
assert not trig.is_valid_transition(unavailable, to_state)
2975-
29762972
# Value not in to_states — not valid
29772973
assert not trig.is_valid_state(wrong_value_state)
29782974

@@ -3043,10 +3039,6 @@ async def test_make_entity_transition_trigger(
30433039
# No change in tracked value — not a valid transition
30443040
assert not trig.is_valid_transition(from_state, from_state)
30453041

3046-
# From unavailable — not valid
3047-
unavailable = State("climate.living", STATE_UNAVAILABLE, {})
3048-
assert not trig.is_valid_transition(unavailable, to_state)
3049-
30503042

30513043
@pytest.mark.parametrize(
30523044
("domain_specs", "origin", "from_state", "to_state", "wrong_from"),

0 commit comments

Comments
 (0)