Skip to content

Commit d3eee04

Browse files
authored
feat: make worker skip running task when it is completed (#406)
* feat: add gokart_worker configurations as same as luigi one * feat: make worker skip run when a task is completed
1 parent 29b29e3 commit d3eee04

File tree

2 files changed

+133
-6
lines changed

2 files changed

+133
-6
lines changed

gokart/worker.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@
6161
from luigi.task_register import TaskClassException, load_task
6262
from luigi.task_status import RUNNING
6363

64-
logger = logging.getLogger('luigi-interface')
64+
from gokart.parameter import ExplicitBoolParameter
65+
66+
logger = logging.getLogger(__name__)
6567

6668
# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
6769
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
@@ -124,6 +126,7 @@ def __init__(
124126
check_unfulfilled_deps: bool = True,
125127
check_complete_on_run: bool = False,
126128
task_completion_cache: Optional[Dict[str, Any]] = None,
129+
skip_if_completed_pre_run: bool = True,
127130
) -> None:
128131
super(TaskProcess, self).__init__()
129132
self.task = task
@@ -136,12 +139,19 @@ def __init__(
136139
self.check_unfulfilled_deps = check_unfulfilled_deps
137140
self.check_complete_on_run = check_complete_on_run
138141
self.task_completion_cache = task_completion_cache
142+
self.skip_if_completed_pre_run = skip_if_completed_pre_run
139143

140144
# completeness check using the cache
141145
self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)
142146

147+
def _run_task(self) -> Optional[collections.abc.Generator]:
    """Run the wrapped task unless it is already complete.

    When ``skip_if_completed_pre_run`` is enabled and ``check_complete``
    reports the task as complete, ``run()`` is not invoked, a warning is
    logged and ``None`` is returned.  In every other case the value returned
    by ``self.task.run()`` is passed through unchanged.

    Returns:
        Whatever ``self.task.run()`` returns, or ``None`` when the task was
        skipped.
    """
    # The flag is tested first so the completeness check is only paid for
    # when skipping is actually enabled.
    if self.skip_if_completed_pre_run and self.check_complete(self.task):
        # Lazy %-style arguments let the logging framework do the formatting
        # only when the record is emitted; the rendered text is unchanged.
        logger.warning('%s is skipped because the task is already completed.', self.task)
        return None
    return self.task.run()
152+
143153
def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]:
144-
task_gen = self.task.run()
154+
task_gen = self._run_task()
145155

146156
if not isinstance(task_gen, collections.abc.Generator):
147157
return None
@@ -308,6 +318,68 @@ def run(self) -> None:
308318
super(ContextManagedTaskProcess, self).run()
309319

310320

321+
class gokart_worker(luigi.Config):
    """Configuration for the gokart worker.

    You can set these options of section [gokart_worker] in your luigi.cfg file.

    NOTE: use snake_case for this class to match the luigi.Config convention.
    """

    # Options carrying ``config_path`` name an option in the [core] section;
    # presumably luigi also consults that location for the value -- confirm
    # against the luigi ``Parameter`` documentation.
    id = luigi.Parameter(default='', description='Override the auto-generated worker_id')
    ping_interval = luigi.FloatParameter(default=1.0, config_path=dict(section='core', name='worker-ping-interval'))
    keep_alive = luigi.BoolParameter(default=False, config_path=dict(section='core', name='worker-keep-alive'))
    count_uniques = luigi.BoolParameter(
        default=False,
        config_path=dict(section='core', name='worker-count-uniques'),
        description='worker-count-uniques means that we will keep a worker alive only if it has a unique pending task, as well as having keep-alive true',
    )
    count_last_scheduled = luigi.BoolParameter(
        default=False,
        description='Keep a worker alive only if there are pending tasks which it was the last to schedule.',
    )
    wait_interval = luigi.FloatParameter(default=1.0, config_path=dict(section='core', name='worker-wait-interval'))
    wait_jitter = luigi.FloatParameter(default=5.0)

    max_keep_alive_idle_duration = luigi.TimeDeltaParameter(default=datetime.timedelta(0))

    max_reschedules = luigi.IntParameter(default=1, config_path=dict(section='core', name='worker-max-reschedules'))
    timeout = luigi.IntParameter(default=0, config_path=dict(section='core', name='worker-timeout'))
    task_limit = luigi.IntParameter(default=None, config_path=dict(section='core', name='worker-task-limit'))
    retry_external_tasks = luigi.BoolParameter(
        default=False,
        config_path=dict(section='core', name='retry-external-tasks'),
        description='If true, incomplete external tasks will be retested for completion while Luigi is running.',
    )
    # The two descriptions below previously relied on implicit string
    # concatenation without a separating space ("workeron failure",
    # "willNOT be install"); they are now single, correctly spaced literals.
    send_failure_email = luigi.BoolParameter(
        default=True,
        description='If true, send e-mails directly from the worker on failure',
    )
    no_install_shutdown_handler = luigi.BoolParameter(
        default=False,
        description='If true, the SIGUSR1 shutdown handler will NOT be installed on the worker',
    )
    check_unfulfilled_deps = luigi.BoolParameter(
        default=True,
        description='If true, check for completeness of dependencies before running a task',
    )
    check_complete_on_run = luigi.BoolParameter(
        default=False,
        description='If true, only mark tasks as done after running if they are complete. '
        'Regardless of this setting, the worker will always check if external '
        'tasks are complete before marking them as done.',
    )
    force_multiprocessing = luigi.BoolParameter(
        default=False,
        description='If true, use multiprocessing also when running with 1 worker',
    )
    task_process_context = luigi.OptionalParameter(
        default=None,
        description='If set to a fully qualified class name, the class will '
        'be instantiated with a TaskProcess as its constructor parameter and '
        'applied as a context manager around its run() call, so this can be '
        'used for obtaining high level customizable monitoring or logging of '
        'each individual Task run.',
    )
    cache_task_completion = luigi.BoolParameter(
        default=False,
        description='If true, cache the response of successful completion checks '
        'of tasks assigned to a worker. This can especially speed up tasks with '
        'dynamic dependencies but assumes that the completion status does not change '
        'after it was true the first time.',
    )
    # gokart-specific option, declared with gokart's ExplicitBoolParameter.
    skip_if_completed_pre_run: bool = ExplicitBoolParameter(
        default=True,
        description='If true, skip running tasks that are already completed just before the Task is run.',
    )
382+
311383
class Worker:
312384
"""
313385
Worker object communicates with a scheduler.
@@ -319,15 +391,22 @@ class Worker:
319391
"""
320392

321393
def __init__(
322-
self, scheduler: Optional[Scheduler] = None, worker_id: Optional[str] = None, worker_processes: int = 1, assistant: bool = False, **kwargs: Any
394+
self,
395+
scheduler: Optional[Scheduler] = None,
396+
worker_id: Optional[str] = None,
397+
worker_processes: int = 1,
398+
assistant: bool = False,
399+
config: Optional[gokart_worker] = None,
323400
) -> None:
324401
if scheduler is None:
325402
scheduler = Scheduler()
326403

327404
self.worker_processes = int(worker_processes)
328405
self._worker_info = self._generate_worker_info()
329-
330-
self._config = luigi.worker.worker(**kwargs)
406+
if config is None:
407+
self._config = gokart_worker()
408+
else:
409+
self._config = config
331410

332411
worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info)
333412

@@ -836,6 +915,7 @@ def _create_task_process(self, task):
836915
check_unfulfilled_deps=self._config.check_unfulfilled_deps,
837916
check_complete_on_run=self._config.check_complete_on_run,
838917
task_completion_cache=self._task_completion_cache,
918+
skip_if_completed_pre_run=self._config.skip_if_completed_pre_run,
839919
)
840920

841921
def _purge_children(self) -> None:

test/test_worker.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from unittest.mock import Mock
33

44
import luigi
5+
import luigi.worker
56
import pytest
67
from luigi import scheduler
78

89
import gokart
9-
from gokart.worker import Worker
10+
from gokart.worker import Worker, gokart_worker
1011

1112

1213
class _DummyTask(gokart.TaskOnKart):
@@ -33,3 +34,49 @@ def test_run(self, monkeypatch: pytest.MonkeyPatch):
3334
assert worker.add(task)
3435
assert worker.run()
3536
mock_run.assert_called_once()
37+
38+
39+
class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
    """Dummy task whose work is delegated to ``_run`` so a test can patch it.

    Replacing ``_run`` with a mock lets a test observe whether ``run`` was
    executed by the worker.
    """

    task_namespace = __name__

    def _run(self) -> None:
        """Do nothing; tests replace this hook with a mock."""

    def run(self) -> None:
        """Invoke the patchable hook, then dump ``None`` as the task output."""
        self._run()
        self.dump(None)

    def complete(self) -> bool:
        """Always report the task as not complete."""
        return False
50+
51+
52+
class TestWorkerSkipIfCompletedPreRun:
    """Verify that ``run`` is skipped only when the option is on and the task is complete."""

    @pytest.mark.parametrize(
        'skip_if_completed_pre_run,is_completed,expect_skipped',
        [
            pytest.param(True, True, True, id='skipped when completed and skip_if_completed_pre_run is True'),
            pytest.param(True, False, False, id='not skipped when not completed and skip_if_completed_pre_run is True'),
            pytest.param(False, True, False, id='not skipped when completed and skip_if_completed_pre_run is False'),
            pytest.param(False, False, False, id='not skipped when not completed and skip_if_completed_pre_run is False'),
        ],
    )
    def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, skip_if_completed_pre_run: bool, is_completed: bool, expect_skipped: bool) -> None:
        """Run one task through a worker and check whether its ``_run`` hook was called."""
        sch = scheduler.Scheduler()
        worker = Worker(scheduler=sch, config=gokart_worker(skip_if_completed_pre_run=skip_if_completed_pre_run))

        mock_complete = Mock(return_value=is_completed)
        # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
        task = _DummyTaskToCheckSkip(complete_check_at_run=False)
        mock_run = Mock()
        monkeypatch.setattr(task, '_run', mock_run)

        with worker:
            assert worker.add(task)
            # NOTE: mock `complete` after `add` because `add` calls `complete`
            #       to check if the task is already completed.
            monkeypatch.setattr(task, 'complete', mock_complete)
            assert worker.run()

        if expect_skipped:
            mock_run.assert_not_called()
        else:
            mock_run.assert_called_once()

0 commit comments

Comments
 (0)