Skip to content

Commit e065976

Browse files
committed
feat: raise Exception when an SSE error event occurrs
1 parent aed6db8 commit e065976

File tree

9 files changed

+362
-9
lines changed

9 files changed

+362
-9
lines changed

src/aiperf/common/enums/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
ServiceType,
9292
)
9393
from aiperf.common.enums.sse_enums import (
94+
SSEEventType,
9495
SSEFieldType,
9596
)
9697
from aiperf.common.enums.system_enums import (
@@ -155,6 +156,7 @@
155156
"RecordProcessorType",
156157
"RequestRateMode",
157158
"ResultsProcessorType",
159+
"SSEEventType",
158160
"SSEFieldType",
159161
"ServiceRegistrationStatus",
160162
"ServiceRunType",

src/aiperf/common/enums/sse_enums.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ class SSEFieldType(CaseInsensitiveStrEnum):
1212
ID = "id"
1313
RETRY = "retry"
1414
COMMENT = "comment"
15+
16+
17+
class SSEEventType(CaseInsensitiveStrEnum):
18+
"""Event types in an SSE message."""
19+
20+
ERROR = "error"

src/aiperf/common/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ class ShutdownError(AIPerfError):
167167
"""Exception raised when a service encounters an error while shutting down."""
168168

169169

170+
class SSEResponseError(AIPerfError):
171+
"""Exception raised when a SSE response contains an error."""
172+
173+
def __init__(self, message: str, error_code: int = 500) -> None:
174+
self.error_code = error_code
175+
super().__init__(message)
176+
177+
170178
class UnsupportedHookError(AIPerfError):
171179
"""Exception raised when a hook is defined on a class that does not have any base classes that provide that hook type."""
172180

src/aiperf/common/models/error_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,15 @@ def __hash__(self) -> int:
6565
@classmethod
6666
def from_exception(cls, e: BaseException) -> "ErrorDetails":
6767
"""Create an error details object from an exception."""
68-
return cls(
68+
error_details = cls(
6969
type=e.__class__.__name__,
7070
message=cls._safe_repr(e),
7171
cause=cls._safe_repr(e.__cause__) if e.__cause__ else None,
7272
details=[cls._safe_repr(arg) for arg in e.args] if e.args else None,
7373
)
74+
if hasattr(e, "error_code") and isinstance(e.error_code, int):
75+
error_details.code = e.error_code
76+
return error_details
7477

7578

7679
class ExitErrorInfo(AIPerfBaseModel):

src/aiperf/transports/aiohttp_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import aiohttp
88

9+
from aiperf.common.exceptions import SSEResponseError
910
from aiperf.common.mixins import AIPerfLoggerMixin
1011
from aiperf.common.models import (
1112
ErrorDetails,
@@ -102,6 +103,7 @@ async def _request(
102103
):
103104
# Parse SSE stream with optimal performance
104105
async for message in AsyncSSEStreamReader(response.content):
106+
AsyncSSEStreamReader.inspect_message_for_error(message)
105107
record.responses.append(message)
106108
else:
107109
raw_response = await response.text()
@@ -114,7 +116,10 @@ async def _request(
114116
)
115117
)
116118
record.end_perf_ns = time.perf_counter_ns()
117-
119+
except SSEResponseError as e:
120+
record.end_perf_ns = time.perf_counter_ns()
121+
self.error(f"Error in SSE response: {e!r}")
122+
record.error = ErrorDetails.from_exception(e)
118123
except Exception as e:
119124
record.end_perf_ns = time.perf_counter_ns()
120125
self.error(f"Error in aiohttp request: {e!r}")

src/aiperf/transports/sse_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from collections.abc import AsyncIterator
77

88
from aiperf.common.aiperf_logger import AIPerfLogger
9+
from aiperf.common.enums.sse_enums import SSEEventType, SSEFieldType
10+
from aiperf.common.exceptions import SSEResponseError
911
from aiperf.common.models import SSEMessage
1012

1113
_logger = AIPerfLogger(__name__)
@@ -71,9 +73,38 @@ async def read_complete_stream(self) -> list[SSEMessage]:
7173
"""Read the complete SSE stream and return a list of SSE messages."""
7274
messages: list[SSEMessage] = []
7375
async for message in self:
76+
AsyncSSEStreamReader.inspect_message_for_error(message)
7477
messages.append(message)
7578
return messages
7679

80+
@staticmethod
81+
def inspect_message_for_error(message: SSEMessage):
82+
"""Check if the message contains an error event packet and raise an SSEResponseError if so.
83+
84+
If so, look for any comment field and raise an SSEResponseError
85+
with that comment as the error message, otherwise use the full message.
86+
"""
87+
has_error_event = any(
88+
packet.name == SSEFieldType.EVENT and packet.value == SSEEventType.ERROR
89+
for packet in message.packets
90+
)
91+
92+
if has_error_event:
93+
error_message = None
94+
for packet in message.packets:
95+
if packet.name == SSEFieldType.COMMENT:
96+
error_message = packet.value
97+
break
98+
99+
if error_message is None:
100+
error_message = (
101+
f"Unknown error in SSE response: {message.model_dump_json()}"
102+
)
103+
104+
raise SSEResponseError(
105+
f"Error occurred in SSE response: {error_message}", error_code=502
106+
)
107+
77108
async def __aiter__(self) -> AsyncIterator[SSEMessage]:
78109
"""Iterate over the SSE stream in a performant manner and yield parsed SSE messages as they arrive."""
79110

tests/transports/test_aiohttp_client.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,74 @@ async def mock_aiter():
120120
record, expected_response_count=2, expected_response_type=SSEMessage
121121
)
122122

123+
@pytest.mark.asyncio
124+
@pytest.mark.parametrize(
125+
"comment_value,expected_error_text",
126+
[
127+
("Rate limit exceeded", "Rate limit exceeded"),
128+
(None, "Unknown error in SSE response"),
129+
],
130+
)
131+
async def test_sse_stream_error_event_handling(
132+
self,
133+
aiohttp_client: AioHttpClient,
134+
mock_sse_response: Mock,
135+
comment_value: str | None,
136+
expected_error_text: str,
137+
) -> None:
138+
"""Test that SSE error events are properly caught and handled in the client."""
139+
from aiperf.common.enums import SSEEventType, SSEFieldType
140+
from aiperf.common.models import SSEField
141+
142+
packets = [
143+
SSEField(name=SSEFieldType.EVENT, value=SSEEventType.ERROR),
144+
]
145+
if comment_value:
146+
packets.append(SSEField(name=SSEFieldType.COMMENT, value=comment_value))
147+
packets.append(SSEField(name=SSEFieldType.DATA, value="{}"))
148+
149+
mock_error_message = SSEMessage(perf_ns=123456789, packets=packets)
150+
151+
with (
152+
patch("aiohttp.ClientSession") as mock_session_class,
153+
patch(
154+
"aiperf.transports.aiohttp_client.AsyncSSEStreamReader"
155+
) as mock_reader_class,
156+
):
157+
158+
async def mock_content_iter():
159+
yield b"event: error\n"
160+
if comment_value:
161+
yield f": {comment_value}\n".encode()
162+
yield b"data: {}\n\n"
163+
164+
mock_sse_response.content = mock_content_iter()
165+
166+
setup_mock_session(mock_session_class, mock_sse_response, ["request"])
167+
168+
async def mock_aiter():
169+
yield mock_error_message
170+
from aiperf.transports.sse_utils import AsyncSSEStreamReader
171+
172+
AsyncSSEStreamReader.inspect_message_for_error(mock_error_message)
173+
174+
mock_reader = Mock()
175+
mock_reader.__aiter__ = Mock(return_value=mock_aiter())
176+
mock_reader_class.return_value = mock_reader
177+
178+
record = await aiohttp_client.post_request(
179+
"http://test.com/stream",
180+
'{"stream": true}',
181+
{"Accept": "text/event-stream"},
182+
)
183+
184+
assert record.error is not None
185+
assert record.error.code == 502
186+
assert record.error.type == "SSEResponseError"
187+
assert expected_error_text in record.error.message
188+
assert len(record.responses) == 1
189+
assert isinstance(record.responses[0], SSEMessage)
190+
123191
@pytest.mark.asyncio
124192
@pytest.mark.parametrize(
125193
"status_code,reason,error_text",
@@ -184,13 +252,9 @@ async def test_exception_handling(
184252
"exception_class,message,expected_type",
185253
[
186254
(aiohttp.ClientConnectorError, "Connection failed", "ClientConnectorError"),
187-
(
188-
aiohttp.ClientResponseError,
189-
"Internal Server Error",
190-
"ClientResponseError",
191-
),
255+
(aiohttp.ClientResponseError, "Internal Server Error", "ClientResponseError"),
192256
],
193-
)
257+
) # fmt: skip
194258
async def test_aiohttp_specific_exceptions(
195259
self,
196260
aiohttp_client: AioHttpClient,

tests/transports/test_aiohttp_sse.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytest
1010

11+
from aiperf.common.exceptions import SSEResponseError
1112
from aiperf.common.models import SSEMessage
1213
from aiperf.transports.sse_utils import AsyncSSEStreamReader
1314

@@ -448,7 +449,7 @@ async def test_aiter_crlf_all_field_types(self) -> None:
448449
"""Test __aiter__ with CRLF and all SSE field types."""
449450
chunks = [
450451
b"data: test\r\nevent: custom\r\nid: msg-123\r\nretry: 5000\r\n: comment\r\n\r\n"
451-
]
452+
] # fmt: skip
452453

453454
reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))
454455
messages = await self._collect_messages(reader)
@@ -529,3 +530,47 @@ async def test_aiter_crlf_performance(self) -> None:
529530
assert processing_time < 3.0, (
530531
f"CRLF processing took {processing_time:.3f}s, expected < 3s"
531532
)
533+
534+
@pytest.mark.asyncio
535+
@pytest.mark.parametrize(
536+
"chunks,expected_error",
537+
[
538+
([b"data: Normal message\n\n", b"event: error\n: Rate limit\ndata: {}\n\n"], "Rate limit"),
539+
([b"event: error\ndata: Something went wrong\n\n"], "Unknown error in SSE response"),
540+
([b"event: error\r\n: Server error\r\ndata: {}\r\n\r\n"], "Server error"),
541+
([b"event: error\n: Connection timeout\n\n"], "Connection timeout"),
542+
([b"data: Message 1\n\n", b"data: Message 2\n\n", b"event: error\n: Fatal error\n\n"], "Fatal error"),
543+
([b'event: error\n: Internal error\ndata: {"error_code": 500}\n\n'], "Internal error"),
544+
],
545+
) # fmt: skip
546+
async def test_error_events_raise_in_read_complete_stream(
547+
self, chunks: list[bytes], expected_error: str
548+
) -> None:
549+
"""Test that various error events raise SSEResponseError."""
550+
reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))
551+
552+
with pytest.raises(SSEResponseError) as exc_info:
553+
await reader.read_complete_stream()
554+
555+
assert expected_error in str(exc_info.value)
556+
assert exc_info.value.error_code == 502
557+
558+
@pytest.mark.asyncio
559+
async def test_error_in_manual_iteration_with_inspect(self) -> None:
560+
"""Test that manual iteration with inspect raises on error event."""
561+
chunks = [
562+
b"data: First message\n\n",
563+
b"event: error\n: Authentication failed\n\n",
564+
b"data: Should not reach\n\n",
565+
]
566+
567+
reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))
568+
messages = []
569+
570+
with pytest.raises(SSEResponseError) as exc_info:
571+
async for message in reader:
572+
AsyncSSEStreamReader.inspect_message_for_error(message)
573+
messages.append(message)
574+
575+
assert len(messages) == 1
576+
assert "Authentication failed" in str(exc_info.value)

0 commit comments

Comments
 (0)