15
15
16
16
import dataclasses
17
17
import os
18
+ import sys
18
19
import time
19
- from typing import List
20
+ from typing import List , Optional
20
21
from unittest import mock
21
22
22
23
from absl .testing import parameterized
36
37
from tfx .proto .orchestration import run_state_pb2
37
38
from tfx .utils import json_utils
38
39
from tfx .utils import status as status_lib
40
+
39
41
import ml_metadata as mlmd
40
42
from ml_metadata .proto import metadata_store_pb2
41
43
@@ -155,9 +157,19 @@ def test_node_state_json(self):
155
157
156
158
class TestEnv (env ._DefaultEnv ):
157
159
158
- def __init__ (self , base_dir , max_str_len ):
160
def __init__(
    self,
    *,
    base_dir: Optional[str],
    max_str_len: int,
    max_task_schedulers: int,
):
  """Records the settings this test environment hands back to callers.

  All arguments are keyword-only.

  Args:
    base_dir: Stored on the instance as `base_dir`; may be None.
    max_str_len: Stored on the instance as `max_str_len`.
    max_task_schedulers: Stored on the instance as `max_task_schedulers`.
  """
  self.max_task_schedulers = max_task_schedulers
  self.max_str_len = max_str_len
  self.base_dir = base_dir
170
+
171
def maximum_active_task_schedulers(self) -> int:
  """Returns the `max_task_schedulers` value stored on this instance."""
  scheduler_cap = self.max_task_schedulers
  return scheduler_cap
161
173
162
174
def get_base_dir(self):
  """Returns the `base_dir` value stored on this instance (may be None)."""
  configured_dir = self.base_dir
  return configured_dir
@@ -216,7 +228,9 @@ def test_new_pipeline_state(self):
216
228
self .assertTrue (pstate ._active_owned_pipelines_exist )
217
229
218
230
def test_new_pipeline_state_with_sub_pipelines (self ):
219
- with self ._mlmd_connection as m :
231
+ with TestEnv (
232
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = 2
233
+ ), self ._mlmd_connection as m :
220
234
pstate ._active_owned_pipelines_exist = False
221
235
pipeline = _test_pipeline ('pipeline1' )
222
236
# Add 2 additional layers of sub pipelines. Note that there is no normal
@@ -276,6 +290,35 @@ def test_new_pipeline_state_with_sub_pipelines(self):
276
290
],
277
291
)
278
292
293
def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
    self,
):
  """Expects `PipelineState.new` to fail when only one task scheduler is allowed.

  Builds 'pipeline1' with two nested sub-pipeline layers inside a `TestEnv`
  whose `max_task_schedulers` is 1, then checks that creating the pipeline
  state raises `StatusNotOkError` with code FAILED_PRECONDITION and a message
  matching 'The maximum number of task schedulers'.
  """
  limited_env = TestEnv(
      base_dir=None, max_str_len=20000, max_task_schedulers=1
  )
  with limited_env, self._mlmd_connection as m:
    pstate._active_owned_pipelines_exist = False
    pipeline = _test_pipeline('pipeline1')
    # Add 2 additional layers of sub pipelines. Note that there is no normal
    # pipeline node in the first pipeline layer.
    _add_sub_pipeline(
        pipeline,
        'sub_pipeline1',
        sub_pipeline_nodes=['Trainer'],
        sub_pipeline_run_id='sub_pipeline1_run0',
    )
    first_layer = pipeline.nodes[0].sub_pipeline
    _add_sub_pipeline(
        first_layer,
        'sub_pipeline2',
        sub_pipeline_nodes=['Trainer'],
        sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
    )
    with self.assertRaisesRegex(
        status_lib.StatusNotOkError,
        'The maximum number of task schedulers',
    ) as raised:
      pstate.PipelineState.new(m, pipeline)
    # The captured exception is inspected once the assertRaisesRegex block
    # has exited.
    self.assertEqual(
        raised.exception.code, status_lib.Code.FAILED_PRECONDITION
    )
321
+
279
322
def test_load_pipeline_state (self ):
280
323
with self ._mlmd_connection as m :
281
324
pipeline = _test_pipeline ('pipeline1' , pipeline_nodes = ['Trainer' ])
@@ -770,7 +813,9 @@ def test_initiate_node_start_stop(self, mock_time):
770
813
def recorder (event ):
771
814
events .append (event )
772
815
773
- with TestEnv (None , 2000 ), event_observer .init (), self ._mlmd_connection as m :
816
+ with TestEnv (
817
+ base_dir = None , max_str_len = 2000 , max_task_schedulers = sys .maxsize
818
+ ), event_observer .init (), self ._mlmd_connection as m :
774
819
event_observer .register_observer (recorder )
775
820
776
821
pipeline = _test_pipeline ('pipeline1' , pipeline_nodes = ['Trainer' ])
@@ -900,7 +945,9 @@ def recorder(event):
900
945
@mock .patch .object (pstate , 'time' )
901
946
def test_get_node_states_dict (self , mock_time ):
902
947
mock_time .time .return_value = time .time ()
903
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
948
+ with TestEnv (
949
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
950
+ ), self ._mlmd_connection as m :
904
951
pipeline = _test_pipeline (
905
952
'pipeline1' ,
906
953
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1120,7 +1167,9 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
1120
1167
@mock .patch .object (pstate , 'time' )
1121
1168
def test_pipeline_view_get_node_run_states (self , mock_time ):
1122
1169
mock_time .time .return_value = time .time ()
1123
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1170
+ with TestEnv (
1171
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1172
+ ), self ._mlmd_connection as m :
1124
1173
pipeline = _test_pipeline (
1125
1174
'pipeline1' ,
1126
1175
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1205,7 +1254,9 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
1205
1254
@mock .patch .object (pstate , 'time' )
1206
1255
def test_pipeline_view_get_node_run_state_history (self , mock_time ):
1207
1256
mock_time .time .return_value = time .time ()
1208
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1257
+ with TestEnv (
1258
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1259
+ ), self ._mlmd_connection as m :
1209
1260
pipeline = _test_pipeline (
1210
1261
'pipeline1' ,
1211
1262
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1252,7 +1303,9 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
1252
1303
):
1253
1304
"""Tests that nodes marked to be skipped have the right node state and previous node state."""
1254
1305
mock_time .time .return_value = time .time ()
1255
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1306
+ with TestEnv (
1307
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1308
+ ), self ._mlmd_connection as m :
1256
1309
pipeline = _test_pipeline (
1257
1310
'pipeline1' ,
1258
1311
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1371,7 +1424,9 @@ def test_load_all_with_list_options(self):
1371
1424
def test_get_previous_node_run_states_for_skipped_nodes (self , mock_time ):
1372
1425
"""Tests that nodes marked to be skipped have the right previous run state."""
1373
1426
mock_time .time .return_value = time .time ()
1374
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1427
+ with TestEnv (
1428
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1429
+ ), self ._mlmd_connection as m :
1375
1430
pipeline = _test_pipeline (
1376
1431
'pipeline1' ,
1377
1432
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1498,7 +1553,9 @@ def test_create_and_load_concurrent_pipeline_runs(self):
1498
1553
)
1499
1554
1500
1555
def test_get_pipeline_and_node (self ):
1501
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1556
+ with TestEnv (
1557
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1558
+ ), self ._mlmd_connection as m :
1502
1559
pipeline = _test_pipeline (
1503
1560
'pipeline1' ,
1504
1561
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1516,7 +1573,9 @@ def test_get_pipeline_and_node(self):
1516
1573
)
1517
1574
1518
1575
def test_get_pipeline_and_node_not_found (self ):
1519
- with TestEnv (None , 20000 ), self ._mlmd_connection as m :
1576
+ with TestEnv (
1577
+ base_dir = None , max_str_len = 20000 , max_task_schedulers = sys .maxsize
1578
+ ), self ._mlmd_connection as m :
1520
1579
pipeline = _test_pipeline (
1521
1580
'pipeline1' ,
1522
1581
execution_mode = pipeline_pb2 .Pipeline .SYNC ,
@@ -1594,7 +1653,9 @@ def test_save_with_max_str_len(self):
1594
1653
state = pstate .NodeState .COMPLETE ,
1595
1654
)
1596
1655
}
1597
- with TestEnv (None , 20 ):
1656
+ with TestEnv (
1657
+ base_dir = None , max_str_len = 20 , max_task_schedulers = sys .maxsize
1658
+ ):
1598
1659
execution = metadata_store_pb2 .Execution ()
1599
1660
proxy = pstate ._NodeStatesProxy (execution )
1600
1661
proxy .set (node_states )
@@ -1605,7 +1666,9 @@ def test_save_with_max_str_len(self):
1605
1666
),
1606
1667
json_utils .dumps (node_states_without_state_history ),
1607
1668
)
1608
- with TestEnv (None , 2000 ):
1669
+ with TestEnv (
1670
+ base_dir = None , max_str_len = 2000 , max_task_schedulers = sys .maxsize
1671
+ ):
1609
1672
execution = metadata_store_pb2 .Execution ()
1610
1673
proxy = pstate ._NodeStatesProxy (execution )
1611
1674
proxy .set (node_states )
0 commit comments