Commit 236ac38

kmontetfx-copybara authored and committed

Fix issue where subpipelines may get stuck due to insufficient task schedulers by raising an error when the total number of subpipelines is greater than the maximum allowable number of task schedulers.

PiperOrigin-RevId: 660525484

1 parent c695752 commit 236ac38
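
Why too few schedulers can wedge a subpipeline: when a subpipeline is scheduled, its start and end nodes are scheduled immediately, so the end node can hold a scheduler slot while waiting on an intermediary node that itself needs a slot. A toy illustration of that contention with a one-slot pool (pure illustration using the Python standard library, not TFX code):

import queue

# Toy model: one scheduler slot. The subpipeline's end node is scheduled
# immediately and occupies the slot while waiting for an intermediary node.
slots = queue.Queue(maxsize=1)

slots.put('end_node')  # end node takes the only slot
try:
  # The intermediary node the end node is waiting on also needs a slot.
  slots.put('intermediary_node', block=False)
except queue.Full:
  print('stuck: intermediary cannot run while end_node holds the only slot')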

2 files changed: 105 additions, 13 deletions

tfx/orchestration/experimental/core/pipeline_state.py

Lines changed: 29 additions & 0 deletions
@@ -14,6 +14,7 @@
 """Pipeline state management functionality."""
 
 import base64
+import collections
 import contextlib
 import copy
 import dataclasses
@@ -515,6 +516,34 @@ def new(
     Raises:
       status_lib.StatusNotOkError: If a pipeline with same UID already exists.
     """
+    num_subpipelines = 0
+    to_process = collections.deque([pipeline])
+    while to_process:
+      p = to_process.popleft()
+      for node in p.nodes:
+        if node.WhichOneof('node') == 'sub_pipeline':
+          num_subpipelines += 1
+          to_process.append(node.sub_pipeline)
+    # If the maximum number of active task schedulers is less than the total
+    # number of subpipelines, subpipelines may get stuck. This is because when
+    # a subpipeline is scheduled, its start node and end node are scheduled
+    # immediately, potentially causing contention where the end node is
+    # waiting on some intermediary node to finish, but the intermediary node
+    # cannot be scheduled because the end node is occupying a task scheduler.
+    # Note that this number is an overestimate - in reality, if subpipelines
+    # depend on each other we may not need this many task schedulers.
+    max_task_schedulers = env.get_env().maximum_active_task_schedulers()
+    if max_task_schedulers < num_subpipelines:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.FAILED_PRECONDITION,
+          message=(
+              f'The maximum number of task schedulers ({max_task_schedulers})'
+              f' is less than the number of subpipelines ({num_subpipelines}).'
+              ' Please set the maximum number of task schedulers to at least'
+              f' {num_subpipelines} in'
+              ' OrchestrationOptions.max_running_components.'
+          ),
+      )
     pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
     context = context_lib.register_context_if_not_exists(
         mlmd_handle,
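
For intuition, the added guard does a breadth-first walk over the pipeline IR, counting sub_pipeline nodes at every nesting depth before any task scheduler is created. Below is a minimal standalone sketch of that traversal; Pipeline and Node here are toy stand-ins for the TFX IR protos, not the real types:

import collections
import dataclasses
from typing import List, Optional

@dataclasses.dataclass
class Node:
  # Stand-in for a pipeline-or-node entry in the IR.
  sub_pipeline: Optional['Pipeline'] = None

  def WhichOneof(self, _oneof_name: str) -> str:
    return 'sub_pipeline' if self.sub_pipeline else 'pipeline_node'

@dataclasses.dataclass
class Pipeline:
  nodes: List[Node]

def count_subpipelines(pipeline: Pipeline) -> int:
  """Counts subpipelines at all nesting depths, breadth-first."""
  count = 0
  to_process = collections.deque([pipeline])
  while to_process:
    p = to_process.popleft()
    for node in p.nodes:
      if node.WhichOneof('node') == 'sub_pipeline':
        count += 1
        to_process.append(node.sub_pipeline)
  return count

# Two nested layers, as in the tests below: pipeline -> sub1 -> sub2.
innermost = Pipeline(nodes=[Node()])
inner = Pipeline(nodes=[Node(sub_pipeline=innermost)])
outer = Pipeline(nodes=[Node(sub_pipeline=inner)])
assert count_subpipelines(outer) == 2

With two nested layers the count is 2, so maximum_active_task_schedulers() must return at least 2 for PipelineState.new to succeed; the tests below exercise exactly this boundary with max_task_schedulers=2 (passes) and max_task_schedulers=1 (fails).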

tfx/orchestration/experimental/core/pipeline_state_test.py

Lines changed: 76 additions & 13 deletions
@@ -15,8 +15,9 @@
 
 import dataclasses
 import os
+import sys
 import time
-from typing import List
+from typing import List, Optional
 from unittest import mock
 
 from absl.testing import parameterized
@@ -36,6 +37,7 @@
 from tfx.proto.orchestration import run_state_pb2
 from tfx.utils import json_utils
 from tfx.utils import status as status_lib
+
 import ml_metadata as mlmd
 from ml_metadata.proto import metadata_store_pb2
 
@@ -155,9 +157,19 @@ def test_node_state_json(self):
 
 class TestEnv(env._DefaultEnv):
 
-  def __init__(self, base_dir, max_str_len):
+  def __init__(
+      self,
+      *,
+      base_dir: Optional[str],
+      max_str_len: int,
+      max_task_schedulers: int
+  ):
     self.base_dir = base_dir
     self.max_str_len = max_str_len
+    self.max_task_schedulers = max_task_schedulers
+
+  def maximum_active_task_schedulers(self) -> int:
+    return self.max_task_schedulers
 
   def get_base_dir(self):
     return self.base_dir
@@ -216,7 +228,9 @@ def test_new_pipeline_state(self):
     self.assertTrue(pstate._active_owned_pipelines_exist)
 
   def test_new_pipeline_state_with_sub_pipelines(self):
-    with self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=2
+    ), self._mlmd_connection as m:
       pstate._active_owned_pipelines_exist = False
       pipeline = _test_pipeline('pipeline1')
       # Add 2 additional layers of sub pipelines. Note that there is no normal
@@ -276,6 +290,35 @@ def test_new_pipeline_state_with_sub_pipelines(self):
         ],
     )
 
+  def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
+      self,
+  ):
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=1
+    ), self._mlmd_connection as m:
+      pstate._active_owned_pipelines_exist = False
+      pipeline = _test_pipeline('pipeline1')
+      # Add 2 additional layers of sub pipelines. Note that there is no normal
+      # pipeline node in the first pipeline layer.
+      _add_sub_pipeline(
+          pipeline,
+          'sub_pipeline1',
+          sub_pipeline_nodes=['Trainer'],
+          sub_pipeline_run_id='sub_pipeline1_run0',
+      )
+      _add_sub_pipeline(
+          pipeline.nodes[0].sub_pipeline,
+          'sub_pipeline2',
+          sub_pipeline_nodes=['Trainer'],
+          sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
+      )
+      with self.assertRaisesRegex(
+          status_lib.StatusNotOkError,
+          'The maximum number of task schedulers',
+      ) as e:
+        pstate.PipelineState.new(m, pipeline)
+      self.assertEqual(e.exception.code, status_lib.Code.FAILED_PRECONDITION)
+
   def test_load_pipeline_state(self):
     with self._mlmd_connection as m:
       pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -770,7 +813,9 @@ def test_initiate_node_start_stop(self, mock_time):
     def recorder(event):
       events.append(event)
 
-    with TestEnv(None, 2000), event_observer.init(), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=2000, max_task_schedulers=sys.maxsize
+    ), event_observer.init(), self._mlmd_connection as m:
       event_observer.register_observer(recorder)
 
       pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -900,7 +945,9 @@ def recorder(event):
   @mock.patch.object(pstate, 'time')
   def test_get_node_states_dict(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1120,7 +1167,9 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
   @mock.patch.object(pstate, 'time')
   def test_pipeline_view_get_node_run_states(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1205,7 +1254,9 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
   @mock.patch.object(pstate, 'time')
   def test_pipeline_view_get_node_run_state_history(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1252,7 +1303,9 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
   ):
     """Tests that nodes marked to be skipped have the right node state and previous node state."""
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1371,7 +1424,9 @@ def test_load_all_with_list_options(self):
   def test_get_previous_node_run_states_for_skipped_nodes(self, mock_time):
     """Tests that nodes marked to be skipped have the right previous run state."""
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1498,7 +1553,9 @@ def test_create_and_load_concurrent_pipeline_runs(self):
     )
 
   def test_get_pipeline_and_node(self):
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1516,7 +1573,9 @@ def test_get_pipeline_and_node(self):
     )
 
   def test_get_pipeline_and_node_not_found(self):
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1594,7 +1653,9 @@ def test_save_with_max_str_len(self):
             state=pstate.NodeState.COMPLETE,
         )
     }
-    with TestEnv(None, 20):
+    with TestEnv(
+        base_dir=None, max_str_len=20, max_task_schedulers=sys.maxsize
+    ):
       execution = metadata_store_pb2.Execution()
      proxy = pstate._NodeStatesProxy(execution)
      proxy.set(node_states)
@@ -1605,7 +1666,9 @@ def test_save_with_max_str_len(self):
         ),
         json_utils.dumps(node_states_without_state_history),
     )
-    with TestEnv(None, 2000):
+    with TestEnv(
+        base_dir=None, max_str_len=2000, max_task_schedulers=sys.maxsize
+    ):
       execution = metadata_store_pb2.Execution()
       proxy = pstate._NodeStatesProxy(execution)
       proxy.set(node_states)
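
At call sites, the new failure mode surfaces from PipelineState.new as a StatusNotOkError with code FAILED_PRECONDITION, as the new test asserts. A hedged sketch of handling it (mlmd_handle and pipeline below are stand-ins for a connected MLMD handle and a compiled pipeline IR; the imports mirror the modules used in these files):

from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.utils import status as status_lib

def start_pipeline(mlmd_handle, pipeline):
  try:
    return pstate.PipelineState.new(mlmd_handle, pipeline)
  except status_lib.StatusNotOkError as e:
    if e.code == status_lib.Code.FAILED_PRECONDITION:
      # Too many subpipelines for the configured scheduler limit. Per the
      # error message, raise OrchestrationOptions.max_running_components to
      # at least the total number of nested subpipelines, then retry.
      print(e)  # the message states the required minimum
    raise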
