Skip to content

Commit 9af76d2

Browse files
authored
Merge pull request #38 from airflow-laminar/tkp/tests
Add tests, fixes #22
2 parents 3688934 + 8a00940 commit 9af76d2

File tree

6 files changed

+299
-25
lines changed

6 files changed

+299
-25
lines changed

airflow_ha/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22

33
from .common import *
44
from .operator import *
5-
from .utils import *

airflow_ha/operator.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from airflow.operators.python import BranchPythonOperator, PythonOperator
77
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
88
from airflow.sensors.python import PythonSensor
9+
from airflow_common_operators import fail, pass_
910

1011
from .common import Action, CheckResult, Result
11-
from .utils import fail_, pass_
1212

1313
__all__ = ("HighAvailabilityOperator",)
1414
_log = getLogger(__name__)
@@ -114,10 +114,10 @@ def __init__(
114114

115115
# this is needed to ensure the dag fails, since the
116116
# retrigger_fail step will pass (to ensure dag retriggers!)
117-
self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail_)
117+
self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail)
118118

119119
self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_)
120-
self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail_)
120+
self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail)
121121

122122
# Update the retrigger counts in trigger kwargs
123123
retrigger_count_conf = f'''{{{{ (ti.dag_run.conf.get("{self.task_id}-retrigger", 0)|int) + 1 }}}}'''
@@ -175,6 +175,10 @@ def __init__(
175175
self._decide_task >> self._retrigger_pass
176176
self._decide_task >> self._retrigger_fail >> self._fail
177177

178+
@property
179+
def decide_task(self) -> PythonOperator:
180+
return self._decide_task
181+
178182
@property
179183
def stop_fail(self) -> PythonOperator:
180184
return self._stop_fail
@@ -238,7 +242,7 @@ def _check_end_conditions(task_id, runtime, endtime, maxretrigger, start_date_or
238242
_log.info(
239243
f"airflow-ha configuration -- endtime: {endtime}, endtime_as_datetime: {endtime_as_datetime}, datetime.now(tz=UTC): {datetime.now(tz=UTC)}"
240244
)
241-
if endtime_as_datetime < datetime.now(tz=UTC):
245+
if endtime_as_datetime <= datetime.now(tz=UTC):
242246
# Endtime has passed, end
243247
_log.info(f"endtime passed for {task_id}, stopping")
244248
return Result.PASS, Action.STOP

airflow_ha/tests/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from airflow.models import DAG
2+
from pytest import fixture
3+
4+
from airflow_ha import Action, HighAvailabilityOperator, Result
5+
6+
7+
@fixture(autouse=True)
8+
def operator():
9+
callable = lambda **kwargs: (Result.PASS, Action.CONTINUE) # noqa: E731
10+
dag = DAG(dag_id="test_dag", default_args={}, schedule=None, params={})
11+
operator = HighAvailabilityOperator(task_id="test_task", python_callable=callable, dag=dag)
12+
return operator

airflow_ha/tests/test_operator.py

+277
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
from datetime import UTC, datetime, timedelta
2+
from unittest.mock import MagicMock, patch
3+
4+
from airflow_ha import Action, HighAvailabilityOperator, Result
5+
from airflow_ha.operator import _callable_wrapper, _choose_branch
6+
7+
8+
class TestHighAvailabilityOperator:
9+
def test_instantiation(self, operator: HighAvailabilityOperator):
10+
assert operator.upstream_list == []
11+
assert operator.downstream_list == [operator.decide_task]
12+
assert operator.retrigger_fail in operator.decide_task.downstream_list
13+
assert operator.retrigger_pass in operator.decide_task.downstream_list
14+
assert operator.stop_pass in operator.decide_task.downstream_list
15+
assert operator.stop_fail in operator.decide_task.downstream_list
16+
assert operator.trigger_rule == "none_failed"
17+
18+
def test_check_end_conditions_default(self, operator: HighAvailabilityOperator):
19+
dag_run_mock = MagicMock()
20+
params_mock = MagicMock()
21+
dag_run_mock.conf.get.side_effect = [
22+
# airflow_ha_force_run
23+
False,
24+
# startdate
25+
datetime(2025, 1, 1, tzinfo=UTC).isoformat(),
26+
# retrigger
27+
0,
28+
]
29+
params_mock.get.side_effect = [
30+
# force-run
31+
False,
32+
# force-runtime
33+
None,
34+
# force-endtime
35+
None,
36+
# maxretrigger
37+
None,
38+
]
39+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) is None
40+
41+
def test_check_end_conditions_force(self, operator: HighAvailabilityOperator):
42+
dag_run_mock = MagicMock()
43+
params_mock = MagicMock()
44+
dag_run_mock.conf.get.side_effect = [
45+
# airflow_ha_force_run
46+
True,
47+
# startdate
48+
datetime(2025, 1, 1, tzinfo=UTC).isoformat(),
49+
# retrigger
50+
0,
51+
]
52+
params_mock.get.side_effect = [
53+
# force-run
54+
False,
55+
# force-runtime
56+
None,
57+
# force-endtime
58+
None,
59+
# maxretrigger
60+
None,
61+
]
62+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) is None
63+
64+
def test_check_end_conditions_runtime(self, operator: HighAvailabilityOperator):
65+
now = datetime.now(tz=UTC)
66+
yesterday_now = now - timedelta(days=1)
67+
yesterday_now_almost = now - timedelta(hours=23, minutes=59, seconds=59)
68+
69+
dag_run_mock = MagicMock()
70+
params_mock = MagicMock()
71+
dag_run_mock.conf.get.side_effect = [
72+
# airflow_ha_force_run
73+
False,
74+
# startdate
75+
yesterday_now_almost.isoformat(),
76+
# retrigger
77+
0,
78+
# airflow_ha_force_run
79+
False,
80+
# startdate
81+
yesterday_now.isoformat(),
82+
# retrigger
83+
0,
84+
]
85+
params_mock.get.side_effect = [
86+
# force-run
87+
False,
88+
# force-runtime
89+
timedelta(days=1),
90+
# force-endtime
91+
None,
92+
# maxretrigger
93+
None,
94+
# force-run
95+
False,
96+
# force-runtime
97+
timedelta(days=1),
98+
# force-endtime
99+
None,
100+
# maxretrigger
101+
None,
102+
]
103+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) is None
104+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) == (Result.PASS, Action.STOP)
105+
106+
def test_check_end_conditions_endtime(self, operator: HighAvailabilityOperator):
107+
with patch("airflow_ha.operator.datetime") as mock_datetime:
108+
now = datetime.now(tz=UTC)
109+
one_second_ago = now - timedelta(seconds=1)
110+
mock_datetime.combine = datetime.combine
111+
mock_datetime.fromisoformat = datetime.fromisoformat
112+
mock_datetime.now.return_value = one_second_ago
113+
114+
dag_run_mock = MagicMock()
115+
params_mock = MagicMock()
116+
dag_run_mock.conf.get.side_effect = [
117+
# airflow_ha_force_run
118+
False,
119+
# startdate
120+
one_second_ago.isoformat(),
121+
# retrigger
122+
0,
123+
# airflow_ha_force_run
124+
False,
125+
# startdate
126+
one_second_ago.isoformat(),
127+
# retrigger
128+
0,
129+
]
130+
params_mock.get.side_effect = [
131+
# force-run
132+
False,
133+
# force-runtime
134+
None,
135+
# force-endtime
136+
now.time(),
137+
# maxretrigger
138+
None,
139+
# force-run
140+
False,
141+
# force-runtime
142+
None,
143+
# force-endtime
144+
now.time(),
145+
# maxretrigger
146+
None,
147+
]
148+
assert operator.check_end_conditions(dag=operator.dag, dag_run=dag_run_mock, params=params_mock) is None
149+
mock_datetime.now.return_value = now
150+
assert operator.check_end_conditions(dag=operator.dag, dag_run=dag_run_mock, params=params_mock) == (Result.PASS, Action.STOP)
151+
152+
def test_check_end_conditions_maxretrigger(self, operator: HighAvailabilityOperator):
153+
dag_run_mock = MagicMock()
154+
params_mock = MagicMock()
155+
dag_run_mock.conf.get.side_effect = [
156+
# airflow_ha_force_run
157+
False,
158+
# startdate
159+
None,
160+
# retrigger
161+
0,
162+
# airflow_ha_force_run
163+
False,
164+
# startdate
165+
None,
166+
# retrigger
167+
1,
168+
]
169+
params_mock.get.side_effect = [
170+
# force-run
171+
False,
172+
# force-runtime
173+
None,
174+
# force-endtime
175+
None,
176+
# maxretrigger
177+
1,
178+
# force-run
179+
False,
180+
# force-runtime
181+
None,
182+
# force-endtime
183+
None,
184+
# maxretrigger
185+
1,
186+
]
187+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) is None
188+
assert operator.check_end_conditions(dag_run=dag_run_mock, params=params_mock) == (None, Action.STOP)
189+
190+
def test_choose_branch(self):
191+
task_instance_mock = MagicMock()
192+
branch_choices = {
193+
(Result.PASS, Action.RETRIGGER): "a",
194+
(Result.PASS, Action.STOP): "b",
195+
(Result.FAIL, Action.RETRIGGER): "c",
196+
(Result.FAIL, Action.STOP): "d",
197+
}
198+
199+
# default
200+
res = _choose_branch(
201+
task_instance=task_instance_mock, branch_choices=branch_choices, task_id="test_task", check_end_conditions=lambda **kwargs: None
202+
)
203+
assert res == "a"
204+
205+
# pass, stop
206+
task_instance_mock.xcom_pull.return_value = (Result.PASS, Action.STOP)
207+
res = _choose_branch(
208+
task_instance=task_instance_mock, branch_choices=branch_choices, task_id="test_task", check_end_conditions=lambda **kwargs: None
209+
)
210+
assert res == "b"
211+
task_instance_mock.xcom_pull.return_value = (Result.FAIL, Action.RETRIGGER)
212+
res = _choose_branch(
213+
task_instance=task_instance_mock, branch_choices=branch_choices, task_id="test_task", check_end_conditions=lambda **kwargs: None
214+
)
215+
assert res == "c"
216+
task_instance_mock.xcom_pull.return_value = (Result.FAIL, Action.STOP)
217+
res = _choose_branch(
218+
task_instance=task_instance_mock, branch_choices=branch_choices, task_id="test_task", check_end_conditions=lambda **kwargs: None
219+
)
220+
assert res == "d"
221+
222+
# end conditions met
223+
res = _choose_branch(
224+
task_instance=task_instance_mock,
225+
branch_choices=branch_choices,
226+
task_id="test_task",
227+
check_end_conditions=lambda **kwargs: (Result.PASS, Action.STOP),
228+
)
229+
assert res == "b"
230+
231+
# retrigger exceeded
232+
res = _choose_branch(
233+
task_instance=task_instance_mock,
234+
branch_choices=branch_choices,
235+
task_id="test_task",
236+
check_end_conditions=lambda **kwargs: (None, Action.STOP),
237+
)
238+
assert res == "b"
239+
240+
def test_callable_wrapper(self):
241+
task_instance_mock = MagicMock()
242+
243+
# malformed
244+
res = _callable_wrapper(
245+
task_instance=task_instance_mock, python_callable=lambda **kwargs: (None, Action.CONTINUE), check_end_conditions=lambda **kwargs: None
246+
)
247+
assert res is True
248+
249+
# pass and continue
250+
res = _callable_wrapper(
251+
task_instance=task_instance_mock,
252+
python_callable=lambda **kwargs: (Result.PASS, Action.CONTINUE),
253+
check_end_conditions=lambda **kwargs: None,
254+
)
255+
assert res is False
256+
257+
# pass and retrigger (keep testing)
258+
res = _callable_wrapper(
259+
task_instance=task_instance_mock,
260+
python_callable=lambda **kwargs: (Result.PASS, Action.RETRIGGER),
261+
check_end_conditions=lambda **kwargs: None,
262+
)
263+
assert res is True
264+
265+
# pass and stop
266+
res = _callable_wrapper(
267+
task_instance=task_instance_mock, python_callable=lambda **kwargs: (Result.PASS, Action.STOP), check_end_conditions=lambda **kwargs: None
268+
)
269+
assert res is True
270+
271+
# end conditions met
272+
res = _callable_wrapper(
273+
task_instance=task_instance_mock,
274+
python_callable=lambda **kwargs: (None, Action.CONTINUE),
275+
check_end_conditions=lambda **kwargs: (None, Action.STOP),
276+
)
277+
assert res is True

airflow_ha/utils.py

-19
This file was deleted.

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ classifiers = [
3030

3131
dependencies = [
3232
"apache-airflow>=2.8,<3",
33+
"airflow-common-operators>=0.2.0,<0.3",
3334
]
3435

3536
[project.optional-dependencies]
@@ -85,7 +86,7 @@ exclude_also = [
8586
"@(abc\\.)?abstractmethod",
8687
]
8788
ignore_errors = true
88-
fail_under = 25
89+
fail_under = 90
8990

9091
[tool.hatch.build]
9192
artifacts = []

0 commit comments

Comments
 (0)