30 changes: 26 additions & 4 deletions cdmtaskservice/job_state.py
@@ -310,16 +310,20 @@ async def stream_job_logs(
         job_id: str,
         container_num: int,
         user: CTSUser,
-        stderr: bool = False
+        stderr: bool = False,
+        seek: int | None = None,
+        length: int | None = None,
     ) -> tuple[AsyncIterator[bytes], str]:
         """
         Stream the container logs from a job.

         job_id - the job's ID.
         container_num - the container number for which to retrieve logs.
         user - the user requesting the logs.
         stderr - return the stderr logs instead of the stdout logs.
+        seek - the byte offset in the file from which to start reading.
+        length - the number of bytes to read from the file.

         Returns a tuple of a generator that will stream the logfile and the name of the file.
         """
         job = await self.get_job(job_id, user)
@@ -332,7 +336,25 @@ async def stream_job_logs(
             )
         filename = s3errpath if stderr else s3outpath
         s3path = S3Paths([str(Path(job.logpath) / filename)])
-        return self._s3.stream_object(s3path), filename
+
+        if seek is not None:
+            if seek < 0:
+                raise IllegalParameterError(
+                    f"Seek parameter must be >= 0, got {seek}"
+                )
+            # Only fetch file metadata if seek is non-zero
+            if seek > 0:
+                meta = await self._s3.get_object_meta(s3path)
+                if seek >= meta[0].size:
+                    raise IllegalParameterError(
+                        f"Seek parameter {seek} is >= file size {meta[0].size}"
+                    )
+        if length is not None and length < 1:
+            raise IllegalParameterError(
+                f"Length parameter must be >= 1, got {length}"
+            )
+
+        return self._s3.stream_object(s3path, seek=seek, length=length), filename

     async def resend_job_notification(self, job_id: str, user: CTSUser, state: models.JobState):
         """
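
Taken together, the two parameters give stream_job_logs ordinary byte-slice semantics: seek=s with length=n should yield the same bytes as data[s:s + n] over the full log file. A minimal self-contained sketch of that contract, where fake_stream is a hypothetical stand-in for the S3-backed iterator the method returns:

import asyncio

async def fake_stream(data: bytes, seek: int = 0, length: int | None = None):
    # Hypothetical stand-in for the S3-backed generator: yields the requested range.
    end = len(data) if length is None else seek + length
    yield data[seek:end]

async def collect(stream) -> bytes:
    # Drain an async byte iterator into one bytes object, as the tests below do.
    return b"".join([chunk async for chunk in stream])

async def main():
    data = b"imsounique"
    assert await collect(fake_stream(data, seek=5)) == data[5:]             # b"nique"
    assert await collect(fake_stream(data, length=5)) == data[:5]           # b"imsou"
    assert await collect(fake_stream(data, seek=3, length=5)) == data[3:8]  # b"ouniq"

asyncio.run(main())
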
22 changes: 19 additions & 3 deletions cdmtaskservice/routes.py
@@ -344,6 +344,14 @@ async def get_job_exit_codes(
     description="The container / subjob number.",
     ge=0,
 )]
+_ANN_LOG_SEEK = Annotated[int, Query(
+    description="The byte offset in the file from which to start reading.",
+    ge=0,
+)]
+_ANN_LOG_LENGTH = Annotated[int, Query(
+    description="The number of bytes to read from the file.",
+    ge=1,
+)]


 @ROUTER_JOBS.get(
@@ -360,8 +368,10 @@ async def get_job_stdout(
     job_id: _ANN_JOB_ID,
     container_num: _ANN_CONTAINER_NUMBER,
     user: CTSUser=Depends(_AUTH),
+    seek: _ANN_LOG_SEEK = None,
+    length: _ANN_LOG_LENGTH = None,
 ) -> StreamingResponse:
-    return await get_logs(r, job_id, container_num, user)
+    return await get_logs(r, job_id, container_num, user, seek=seek, length=length)


 @ROUTER_JOBS.get(
@@ -378,8 +388,10 @@ async def get_job_stderr(
     job_id: _ANN_JOB_ID,
     container_num: _ANN_CONTAINER_NUMBER,
     user: CTSUser=Depends(_AUTH),
+    seek: _ANN_LOG_SEEK = None,
+    length: _ANN_LOG_LENGTH = None,
 ) -> StreamingResponse:
-    return await get_logs(r, job_id, container_num, user, stderr=True)
+    return await get_logs(r, job_id, container_num, user, stderr=True, seek=seek, length=length)


 async def get_logs(
@@ -388,9 +400,13 @@ async def get_logs(
     container_num: int,
     user: CTSUser,
     stderr: bool = False,
+    seek: int | None = None,
+    length: int | None = None,
 ) -> StreamingResponse:
     jobstate = app_state.get_app_state(r).job_state
-    filegen, filename = await jobstate.stream_job_logs(job_id, container_num, user, stderr=stderr)
+    filegen, filename = await jobstate.stream_job_logs(
+        job_id, container_num, user, stderr=stderr, seek=seek, length=length
+    )
     return StreamingResponse(
         filegen,
         media_type="text/plain; charset=utf-8",
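
From a client's perspective the new parameters are plain query-string arguments, and FastAPI's Query(ge=...) bounds reject out-of-range values with a 422 before the handler runs. A hedged usage sketch — the host, route path, and auth header here are assumptions, since the diff does not show the @ROUTER_JOBS.get(...) path templates:

import httpx

resp = httpx.get(
    "https://cts.example.org/jobs/JOB_ID/stdout/0",  # hypothetical URL
    params={"seek": 100, "length": 50},  # return 50 bytes starting at offset 100
    headers={"Authorization": "Bearer TOKEN"},  # auth scheme is an assumption
)
resp.raise_for_status()
print(resp.text)  # streamed back as text/plain; charset=utf-8
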
33 changes: 27 additions & 6 deletions cdmtaskservice/s3/client.py
@@ -291,25 +291,46 @@ async def head(client, buk=buk, key=key):  # bind the current value of the variable
         ))
         return ret

-    def stream_object(self, s3path: S3Paths) -> AsyncIterator[bytes]:
+    def stream_object(
+        self, s3path: S3Paths, seek: int | None = None, length: int | None = None
+    ) -> AsyncIterator[bytes]:
         """
         Stream an object from S3.

         Note that the S3Paths input must contain only a single path.

         s3path - the path of the object to stream.
+        seek - the byte offset in the file from which to start reading.
+        length - the number of bytes to read from the file.
         """
         if len(_not_falsy(s3path, "s3path")) > 1:
             raise ValueError("Only one path is allowed")
-        return self._stream_object_generator(s3path)
+        if seek is not None:  # TODO CODE add an optional param to _check_num
+            _check_num(seek, "seek", minimum=0)
+        if length is not None:
+            _check_num(length, "length", minimum=1)
+        return self._stream_object_generator(s3path, seek=seek, length=length)

-    async def _stream_object_generator(self, s3path: S3Paths) -> AsyncIterator[bytes]:
+    async def _stream_object_generator(
+        self, s3path: S3Paths, seek: int | None = None, length: int | None = None
+    ) -> AsyncIterator[bytes]:
+        # internal helper returns the raw response
         # we can't use the _run_commands helper method here since the client needs to stay
         # open while the result is streamed
         try:
             try:
                 async with self._client() as client:
                     buk, key, path = next(s3path.split_paths(include_full_path=True))
                     async def go(client, buk=buk, key=key):
-                        return await client.get_object(Bucket=buk, Key=key)
+                        kwargs = {"Bucket": buk, "Key": key}
+                        # Build Range header if seek or length is specified
+                        if seek is not None or length is not None:
+                            start = seek if seek is not None else 0
+                            if length is not None:
+                                end = start + length - 1
+                                kwargs["Range"] = f"bytes={start}-{end}"
+                            else:
+                                kwargs["Range"] = f"bytes={start}-"
+                        return await client.get_object(**kwargs)
                     go.path = path
                     res = await self._fnc_wrapper(client, go)
                     body = res["Body"]
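
One detail worth calling out: the HTTP Range header is inclusive on both ends (RFC 9110), which is why the end offset is start + length - 1, and S3 answers a start offset at or past the object size with a 416/InvalidRange error, which is why job_state.py checks the object size up front to give a friendlier message. A standalone check of the arithmetic — the build_range helper is illustrative, not part of the PR:

def build_range(seek: int | None, length: int | None) -> str | None:
    # Mirrors the Range construction in _stream_object_generator above.
    if seek is None and length is None:
        return None
    start = seek if seek is not None else 0
    if length is not None:
        return f"bytes={start}-{start + length - 1}"  # both ends inclusive
    return f"bytes={start}-"  # open-ended: from start to end of object

assert build_range(None, None) is None
assert build_range(5, None) == "bytes=5-"   # "imsounique"[5:] -> b"nique"
assert build_range(None, 5) == "bytes=0-4"  # first five bytes -> b"imsou"
assert build_range(3, 5) == "bytes=3-7"     # data[3:8] -> b"ouniq"
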
44 changes: 35 additions & 9 deletions test/s3/s3_client_test.py
@@ -422,6 +422,12 @@ async def _download_objects_to_file_fail(
     assert_exception_correct(got.value, expected, print_stacktrace)


+async def _stream_and_assert(s3c, file_path, expected, seek=None, length=None):
+    asynciter = s3c.stream_object(S3Paths([file_path]), seek=seek, length=length)
+    data = b"".join([chunk async for chunk in asynciter])
+    assert data == expected
+
+
 @pytest.mark.asyncio
 async def test_stream_object(minio):
     await minio.clean()  # couldn't get this to work as a fixture
@@ -434,13 +440,17 @@ async def test_stream_object(minio):
     )

     s3c = await _client(minio)
-    asynciter = s3c.stream_object(S3Paths(["test-bucket/test_file"]))
-    data = b"".join([chunk async for chunk in asynciter])
-    assert data == b"imsounique"
-
-    asynciter = s3c.stream_object(S3Paths(["test-bucket/big_test_file"]))
-    data = b"".join([chunk async for chunk in asynciter])
-    assert data == b"abcdefghij" * 600000 * 2
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"imsounique")
+    await _stream_and_assert(s3c, "test-bucket/big_test_file", b"abcdefghij" * 600000 * 2)
+
+    # Test with seek only - "imsounique"[5:] = "nique"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"nique", seek=5)
+
+    # Test with length only - "imsounique"[:5] = "imsou"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"imsou", length=5)
+
+    # Test with both seek and length - "imsounique"[3:8] = "ouniq"
+    await _stream_and_assert(s3c, "test-bucket/test_file", b"ouniq", seek=3, length=5)


 @pytest.mark.asyncio
@@ -452,6 +462,22 @@ async def test_stream_object_fail_bad_input(minio):
         "Only one path is allowed"
     ))

+    # Test with invalid seek values
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("seek must be >= 0"), seek=-1
+    )
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("seek must be >= 0"), seek=-100
+    )
+
+    # Test with invalid length values
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("length must be >= 1"), length=0
+    )
+    await _stream_object_fail(
+        s3c, S3Paths(["foo/bar"]), ValueError("length must be >= 1"), length=-1
+    )


 @pytest.mark.asyncio
 async def test_stream_object_fail_bad_connection(minio):
@@ -501,10 +527,10 @@ async def test_stream_object_fail_unauthed(minio, minio_unauthed_user):


 async def _stream_object_fail(
-    cli: S3Client, path: S3Paths, expected: Exception, print_stacktrace=False
+    cli: S3Client, path: S3Paths, expected: Exception, seek=None, length=None, print_stacktrace=False
 ):
     with pytest.raises(Exception) as got:
-        res = cli.stream_object(path)
+        res = cli.stream_object(path, seek=seek, length=length)
         [chunk async for chunk in res]
     assert_exception_correct(got.value, expected, print_stacktrace)
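
A side note on why _stream_object_fail wraps both the call and the iteration in pytest.raises: stream_object is a regular function that validates its arguments eagerly and returns a lazy async generator, so parameter errors surface at call time while S3 errors only appear once the body is iterated. A minimal standalone illustration of that split, not part of the PR:

import asyncio

def stream(seek: int | None = None):
    # Eager check, mirroring _check_num in stream_object.
    if seek is not None and seek < 0:
        raise ValueError("seek must be >= 0")
    async def gen():
        yield b"data"  # lazy part; S3 errors would only surface here
    return gen()

try:
    stream(seek=-1)  # raises immediately, before any iteration
except ValueError as e:
    print(e)  # seek must be >= 0

async def drain():
    print(b"".join([chunk async for chunk in stream()]))  # b"data"

asyncio.run(drain())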
