Skip to content

Commit 855ef94

Browse files
kitagry and Copilot authored
fix: gcs_object_metadata_client can't handle required_tasks (#467)
* fix: required_tasks error
* fix: delete unnecessary function
  Co-authored-by: Copilot <[email protected]>
* test: nested required_task output
---------
Co-authored-by: Copilot <[email protected]>
1 parent 0b457b3 commit 855ef94

File tree

2 files changed

+37
-13
lines changed

2 files changed

+37
-13
lines changed

gokart/gcs_obj_metadata_client.py

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -98,8 +98,6 @@ def _get_patched_obj_metadata(
9898
if not isinstance(metadata, dict):
9999
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
100100
return metadata
101-
if not task_params and not custom_labels:
102-
return metadata
103101
# Maximum size of metadata for each object is 8 KiB.
104102
# [Link]: https://cloud.google.com/storage/quotas#objects
105103
normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params)
@@ -117,18 +115,17 @@ def _get_patched_obj_metadata(
117115

118116
@staticmethod
119117
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
120-
def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
121-
for item in nested_list:
122-
if isinstance(item, Iterable):
123-
yield from _iterable_flatten(item)
124-
else:
125-
yield item
126-
127-
if isinstance(required_task_outputs, dict):
118+
if isinstance(required_task_outputs, RequiredTaskOutput):
119+
return required_task_outputs.serialize()
120+
elif isinstance(required_task_outputs, dict):
128121
return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
129-
if isinstance(required_task_outputs, Iterable):
130-
return list(_iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]))
131-
return [required_task_outputs.serialize()]
122+
elif isinstance(required_task_outputs, Iterable):
123+
return [GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]
124+
else:
125+
raise TypeError(
126+
f'Unsupported type for required_task_outputs: {type(required_task_outputs)}. '
127+
'It should be RequiredTaskOutput, dict, or iterable of RequiredTaskOutput.'
128+
)
132129

133130
@staticmethod
134131
def _merge_custom_labels_and_task_params_labels(

test/test_gcs_obj_metadata_client.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import gokart
99
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
10+
from gokart.required_task_output import RequiredTaskOutput
1011
from gokart.target import TargetOnKart
1112

1213

@@ -113,6 +114,32 @@ def test_get_patched_obj_metadata_with_conflicts(self):
113114
self.assertEqual(got['created_by'], 'hoge fuga')
114115
self.assertEqual(got['param1'], 'a' * 10)
115116

117+
def test_get_patched_obj_metadata_with_required_task_outputs(self):
118+
got = GCSObjectMetadataClient._get_patched_obj_metadata(
119+
{},
120+
required_task_outputs=[
121+
RequiredTaskOutput(task_name='task1', output_path='path/to/output1'),
122+
],
123+
)
124+
125+
self.assertIsInstance(got, dict)
126+
self.assertIn('__required_task_outputs', got)
127+
self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]')
128+
129+
def test_get_patched_obj_metadata_with_nested_required_task_outputs(self):
130+
got = GCSObjectMetadataClient._get_patched_obj_metadata(
131+
{},
132+
required_task_outputs={
133+
'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')},
134+
},
135+
)
136+
137+
self.assertIsInstance(got, dict)
138+
self.assertIn('__required_task_outputs', got)
139+
self.assertEqual(
140+
got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}'
141+
)
142+
116143

117144
class TestGokartTask(unittest.TestCase):
118145
@patch.object(_DummyTaskOnKart, '_get_output_target')

0 commit comments

Comments
 (0)