diff --git a/cdmtaskservice/job_state.py b/cdmtaskservice/job_state.py
index 711f118..2bee0c4 100644
--- a/cdmtaskservice/job_state.py
+++ b/cdmtaskservice/job_state.py
@@ -310,16 +310,20 @@ async def stream_job_logs(
         job_id: str,
         container_num: int,
         user: CTSUser,
-        stderr: bool = False
+        stderr: bool = False,
+        seek: int | None = None,
+        length: int | None = None,
     ) -> tuple[AsyncIterator[bytes], str]:
         """
         Stream the container logs from a job.
-        
+
         job_id - the job's ID.
         container_num - the container number for which to retrieve logs.
         user - the user requesting the logs.
         stderr - return the stderr logs instead of the stdout logs.
-        
+        seek - the byte offset in the file from which to start reading.
+        length - the number of bytes to read from the file.
+
         Returns a tuple of a generator that will stream the logfile and the name of the file.
         """
         job = await self.get_job(job_id, user)
@@ -332,7 +336,25 @@ async def stream_job_logs(
         )
         filename = s3errpath if stderr else s3outpath
         s3path = S3Paths([str(Path(job.logpath) / filename)])
-        return self._s3.stream_object(s3path), filename
+
+        if seek is not None:
+            if seek < 0:
+                raise IllegalParameterError(
+                    f"Seek parameter must be >= 0, got {seek}"
+                )
+            # Only fetch file metadata if seek is non-zero
+            if seek > 0:
+                meta = await self._s3.get_object_meta(s3path)
+                if seek >= meta[0].size:
+                    raise IllegalParameterError(
+                        f"Seek parameter {seek} is >= file size {meta[0].size}"
+                    )
+        if length is not None and length < 1:
+            raise IllegalParameterError(
+                f"Length parameter must be >= 1, got {length}"
+            )
+
+        return self._s3.stream_object(s3path, seek=seek, length=length), filename
 
     async def resend_job_notification(self, job_id: str, user: CTSUser, state: models.JobState):
         """
diff --git a/cdmtaskservice/routes.py b/cdmtaskservice/routes.py
index f49cd4d..d4970b6 100644
--- a/cdmtaskservice/routes.py
+++ b/cdmtaskservice/routes.py
@@ -344,6 +344,14 @@ async def get_job_exit_codes(
     description="The container / subjob number.",
     ge=0,
 )]
+_ANN_LOG_SEEK = Annotated[int | None, Query(
+    description="The byte offset in the file from which to start reading.",
+    ge=0,
+)]
+_ANN_LOG_LENGTH = Annotated[int | None, Query(
+    description="The number of bytes to read from the file.",
+    ge=1,
+)]
 
 
 @ROUTER_JOBS.get(
@@ -360,8 +368,10 @@ async def get_job_stdout(
     job_id: _ANN_JOB_ID,
     container_num: _ANN_CONTAINER_NUMBER,
     user: CTSUser=Depends(_AUTH),
+    seek: _ANN_LOG_SEEK = None,
+    length: _ANN_LOG_LENGTH = None,
 ) -> StreamingResponse:
-    return await get_logs(r, job_id, container_num, user)
+    return await get_logs(r, job_id, container_num, user, seek=seek, length=length)
 
 
 @ROUTER_JOBS.get(
@@ -378,8 +388,10 @@ async def get_job_stderr(
     job_id: _ANN_JOB_ID,
     container_num: _ANN_CONTAINER_NUMBER,
     user: CTSUser=Depends(_AUTH),
+    seek: _ANN_LOG_SEEK = None,
+    length: _ANN_LOG_LENGTH = None,
 ) -> StreamingResponse:
-    return await get_logs(r, job_id, container_num, user, stderr=True)
+    return await get_logs(r, job_id, container_num, user, stderr=True, seek=seek, length=length)
 
 
 async def get_logs(
@@ -388,9 +400,13 @@ async def get_logs(
     container_num: int,
     user: CTSUser,
     stderr: bool = False,
+    seek: int | None = None,
+    length: int | None = None,
 ) -> StreamingResponse:
     jobstate = app_state.get_app_state(r).job_state
-    filegen, filename = await jobstate.stream_job_logs(job_id, container_num, user, stderr=stderr)
+    filegen, filename = await jobstate.stream_job_logs(
+        job_id, container_num, user, stderr=stderr, seek=seek, length=length
+    )
     return StreamingResponse(
         filegen,
         media_type="text/plain; charset=utf-8",
diff --git a/cdmtaskservice/s3/client.py b/cdmtaskservice/s3/client.py
index e936c3a..5ea240d 100644
--- a/cdmtaskservice/s3/client.py
+++ b/cdmtaskservice/s3/client.py
@@ -291,25 +291,47 @@ async def head(client, buk=buk, key=key): # bind the current value of the varia
             ))
         return ret
 
-    def stream_object(self, s3path: S3Paths) -> AsyncIterator[bytes]:
+    def stream_object(
+        self, s3path: S3Paths, seek: int | None = None, length: int | None = None
+    ) -> AsyncIterator[bytes]:
         """
         Stream an object from S3.
-        
+
         Note that the S3Paths input must contain only a single path.
+
+        s3path - the path of the object to stream.
+        seek - the byte offset in the file from which to start reading.
+        length - the number of bytes to read from the file.
         """
         if len(_not_falsy(s3path, "s3path")) > 1:
             raise ValueError("Only one path is allowed")
-        return self._stream_object_generator(s3path)
+        if seek is not None:  # TODO CODE add an optional param to _check_num
+            _check_num(seek, "seek", minimum=0)
+        if length is not None:
+            _check_num(length, "length", minimum=1)
+        return self._stream_object_generator(s3path, seek=seek, length=length)
 
-    async def _stream_object_generator(self, s3path: S3Paths) -> AsyncIterator[bytes]:
+    async def _stream_object_generator(
+        self, s3path: S3Paths, seek: int | None = None, length: int | None = None
+    ) -> AsyncIterator[bytes]:
         # internal helper returns the raw response
         # we can't use the _run_commands helper method here since the client needs to stay
         # open while the result is streamed
-        try: 
+        try:
             async with self._client() as client:
                 buk, key, path = next(s3path.split_paths(include_full_path=True))
                 async def go(client, buk=buk, key=key):
-                    return await client.get_object(Bucket=buk, Key=key)
+                    kwargs = {"Bucket": buk, "Key": key}
+                    # Only send a Range header when one is actually needed: S3 rejects
+                    # ranged GETs on zero-byte objects, so seek=0 alone stays a plain GET.
+                    if seek or length is not None:
+                        start = seek if seek is not None else 0
+                        if length is not None:
+                            end = start + length - 1
+                            kwargs["Range"] = f"bytes={start}-{end}"
+                        else:
+                            kwargs["Range"] = f"bytes={start}-"
+                    return await client.get_object(**kwargs)
                 go.path = path
                 res = await self._fnc_wrapper(client, go)
                 body = res["Body"]
diff --git a/test/s3/s3_client_test.py b/test/s3/s3_client_test.py
index 48a48e1..e43d345 100644
--- a/test/s3/s3_client_test.py
+++ b/test/s3/s3_client_test.py
@@ -422,6 +422,12 @@ async def _download_objects_to_file_fail(
     assert_exception_correct(got.value, expected, print_stacktrace)
 
 
+async def _stream_and_assert(s3c, file_path, expected, seek=None, length=None):
+    asynciter = s3c.stream_object(S3Paths([file_path]), seek=seek, length=length)
+    data = b"".join([chunk async for chunk in asynciter])
+    assert data == expected
+
+
 @pytest.mark.asyncio
 async def test_stream_object(minio):
     await minio.clean()  # couldn't get this to work as a fixture
@@ -434,13 +440,17 @@ async def test_stream_object(minio):
     )
     s3c = await _client(minio)
 
-    asynciter = s3c.stream_object(S3Paths(["test-bucket/test_file"]))
-    data = b"".join([chunk async for chunk in asynciter])
-    assert data == b"imsounique"
-
-    asynciter = s3c.stream_object(S3Paths(["test-bucket/big_test_file"]))
-    data = b"".join([chunk async for chunk in asynciter])
-    assert data == b"abcdefghij" * 600000 * 2
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"imsounique")
+    await _stream_and_assert(s3c, "test-bucket/big_test_file", b"abcdefghij" * 600000 * 2)
+
+    # Test with seek only - "imsounique"[5:] = "nique"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"nique", seek=5)
+
+    # Test with length only - "imsounique"[:5] = "imsou"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"imsou", length=5)
+
+    # Test with both seek and length - "imsounique"[3:8] = "ouniq"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"ouniq", seek=3, length=5)
 
 
 @pytest.mark.asyncio
@@ -452,6 +462,22 @@ async def test_stream_object_fail_bad_input(minio):
         "Only one path is allowed"
     ))
 
+    # Test with invalid seek values
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("seek must be >= 0"), seek=-1
+    )
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("seek must be >= 0"), seek=-100
+    )
+
+    # Test with invalid length values
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("length must be >= 1"), length=0
+    )
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("length must be >= 1"), length=-1
+    )
+
 
 @pytest.mark.asyncio
 async def test_stream_object_fail_bad_connection(minio):
@@ -501,10 +527,11 @@ async def test_stream_object_fail_unauthed(minio, minio_unauthed_user):
 
 
 async def _stream_object_fail(
-    cli: S3Client, path: S3Paths, expected: Exception, print_stacktrace=False
+    cli: S3Client, path: S3Paths, expected: Exception, print_stacktrace=False,
+    *, seek=None, length=None,
 ):
     with pytest.raises(Exception) as got:
-        res = cli.stream_object(path)
+        res = cli.stream_object(path, seek=seek, length=length)
         [chunk async for chunk in res]
     assert_exception_correct(got.value, expected, print_stacktrace)
 