Skip to content

Commit 59eced7

Browse files
authored
[CAT-102] pr-ing missed commit for bytestream (#119)
1 parent ed4dfcb commit 59eced7

File tree

5 files changed

+72
-31
lines changed

5 files changed

+72
-31
lines changed

indico/http/client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def _handle_files(self, req_kwargs):
8787
# deepcopying buffers is not supported
8888
# so, remove "streams" before the deepcopy.
8989
if "streams" in req_kwargs:
90-
streams = req_kwargs["streams"].copy()
90+
if req_kwargs["streams"] is not None:
91+
streams = req_kwargs["streams"].copy()
9192
del req_kwargs["streams"]
9293

9394
new_kwargs = deepcopy(req_kwargs)

indico/queries/datasets.py

+1
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ class _AddFiles(GraphQLRequest):
319319
}
320320
}
321321
"""
322+
322323

323324
def __init__(self, dataset_id: int, metadata: List[str]):
324325
super().__init__(

indico/queries/workflow.py

+33-29
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ListWorkflows(GraphQLRequest):
3535
"""
3636

3737
def __init__(
38-
self, *, dataset_ids: List[int] = None, workflow_ids: List[int] = None
38+
self, *, dataset_ids: List[int] = None, workflow_ids: List[int] = None
3939
):
4040
super().__init__(
4141
self.query,
@@ -114,10 +114,10 @@ class UpdateWorkflowSettings(RequestChain):
114114
"""
115115

116116
def __init__(
117-
self,
118-
workflow: Union[Workflow, int],
119-
enable_review: bool = None,
120-
enable_auto_review: bool = None,
117+
self,
118+
workflow: Union[Workflow, int],
119+
enable_review: bool = None,
120+
enable_auto_review: bool = None,
121121
):
122122
self.workflow_id = workflow.id if isinstance(workflow, Workflow) else workflow
123123
if enable_review is None and enable_auto_review is None:
@@ -134,7 +134,6 @@ def requests(self):
134134

135135

136136
class _WorkflowSubmission(GraphQLRequest):
137-
138137
query = """
139138
mutation workflowSubmissionMutation({signature}) {{
140139
{mutation_name}({args}) {{
@@ -180,9 +179,9 @@ class _WorkflowSubmission(GraphQLRequest):
180179
}
181180

182181
def __init__(
183-
self,
184-
detailed_response: bool,
185-
**kwargs,
182+
self,
183+
detailed_response: bool,
184+
**kwargs,
186185
):
187186
self.workflow_id = kwargs["workflow_id"]
188187
self.record_submission = kwargs["record_submission"]
@@ -263,28 +262,33 @@ class WorkflowSubmission(RequestChain):
263262
detailed_response = False
264263

265264
def __init__(
266-
self,
267-
workflow_id: int,
268-
files: List[str] = None,
269-
urls: List[str] = None,
270-
submission: bool = True,
271-
bundle: bool = False,
272-
result_version: str = None,
273-
streams: Dict[str, io.BufferedIOBase] = None
265+
self,
266+
workflow_id: int,
267+
files: List[str] = None,
268+
urls: List[str] = None,
269+
submission: bool = True,
270+
bundle: bool = False,
271+
result_version: str = None,
272+
streams: Dict[str, io.BufferedIOBase] = None
274273
):
275274
self.workflow_id = workflow_id
276275
self.files = files
277276
self.urls = urls
278277
self.submission = submission
279278
self.bundle = bundle
280279
self.result_version = result_version
281-
self.streams = streams.copy()
282-
283-
if not self.files and not self.urls and not len(streams) > 0:
280+
self.has_streams = False
281+
if streams is not None:
282+
self.streams = streams.copy()
283+
self.has_streams = True
284+
else:
285+
self.streams = None
286+
287+
if not self.files and not self.urls and not self.has_streams:
284288
raise IndicoInputError("One of 'files', 'streams', or 'urls' must be specified")
285-
elif self.files and len(self.streams) > 0:
289+
elif self.files and self.has_streams:
286290
raise IndicoInputError("Only one of 'files' or 'streams' or 'urls' may be specified.")
287-
elif (self.files or len(streams) > 0) and self.urls:
291+
elif (self.files or self.has_streams) and self.urls:
288292
raise IndicoInputError("Only one of 'files' or 'streams' or 'urls' may be specified")
289293

290294
def requests(self):
@@ -307,7 +311,7 @@ def requests(self):
307311
bundle=self.bundle,
308312
result_version=self.result_version,
309313
)
310-
elif len(self.streams) > 0:
314+
elif self.has_streams:
311315
yield UploadDocument(streams=self.streams)
312316
yield _WorkflowSubmission(
313317
self.detailed_response,
@@ -343,12 +347,12 @@ class WorkflowSubmissionDetailed(WorkflowSubmission):
343347
detailed_response = True
344348

345349
def __init__(
346-
self,
347-
workflow_id: int,
348-
files: List[str] = None,
349-
urls: List[str] = None,
350-
bundle: bool = False,
351-
result_version: str = None,
350+
self,
351+
workflow_id: int,
352+
files: List[str] = None,
353+
urls: List[str] = None,
354+
bundle: bool = False,
355+
result_version: str = None,
352356
):
353357
super().__init__(
354358
workflow_id,

indico/types/document_report.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class DocumentReport(BaseType):
1212
"""
1313
A Document report about the associated InputFiles.
1414
15-
15+
1616
"""
1717
dataset_id: int
1818
workflow_id: int

tests/integration/queries/test_workflow.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import io
2+
13
from indico.queries.workflow import GetWorkflow
24
import pytest
35
from pathlib import Path
@@ -96,6 +98,39 @@ def test_workflow_submission(
9698
assert isinstance(sub, Submission)
9799
assert sub.retrieved is True
98100

101+
def test_workflow_submission_with_streams(
102+
indico, airlines_dataset, airlines_model_group: ModelGroup
103+
):
104+
client = IndicoClient()
105+
wfs = client.call(ListWorkflows(dataset_ids=[airlines_dataset.id]))
106+
wf = max(wfs, key=lambda w: w.id)
107+
108+
path = Path(str(Path(__file__).parents[1]) + "/data/mock.pdf")
109+
fd = open(path.absolute(), "rb")
110+
files = {
111+
"mock.pdf": fd
112+
}
113+
submission_ids = client.call(WorkflowSubmission(workflow_id=wf.id, streams=files))
114+
submission_id = submission_ids[0]
115+
assert submission_id is not None
116+
117+
with pytest.raises(IndicoInputError):
118+
client.call(SubmissionResult(submission_id, "FAILED"))
119+
120+
with pytest.raises(IndicoInputError):
121+
client.call(SubmissionResult(submission_id, "INVALID_STATUS"))
122+
123+
result_url = client.call(SubmissionResult(submission_id, "COMPLETE", wait=True))
124+
result = client.call(RetrieveStorageObject(result_url.result))
125+
assert isinstance(result, dict)
126+
assert result["submission_id"] == submission_id
127+
assert result["file_version"] == 1
128+
client.call(UpdateSubmission(submission_id, retrieved=True))
129+
sub = client.call(GetSubmission(submission_id))
130+
assert isinstance(sub, Submission)
131+
assert sub.retrieved is True
132+
133+
99134

100135
@pytest.mark.parametrize(
101136
"_input",

0 commit comments

Comments
 (0)