
Commit 9f78805

feat: add streaming support utilities
1 parent 985ec36 commit 9f78805

3 files changed: 300 additions, 0 deletions

src/gradient/_utils/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -37,6 +37,8 @@
     BatchProcessor as BatchProcessor,
     DataExporter as DataExporter,
     PaginationHelper as PaginationHelper,
+    StreamProcessor as StreamProcessor,
+    StreamCollector as StreamCollector,
 )
 from ._compat import (
     get_args as get_args,
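These re-exports are what the new tests (and any downstream code) rely on to pull the helpers from the package-level namespace:

    from gradient._utils import StreamProcessor, StreamCollector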

src/gradient/_utils/_utils.py
Lines changed: 136 additions & 0 deletions

@@ -838,6 +838,142 @@ async def paginate_async(self, fetch_func: Callable[[dict[str, Any]], Any], **kw
         return all_items


+# Streaming Support Classes
+class StreamProcessor:
+    """Utility for processing streaming API responses with custom handlers."""
+
+    def __init__(self) -> None:
+        """Initialize stream processor."""
+        self._handlers: dict[str, Callable[[Any], Any]] = {}
+
+    def add_handler(self, event_type: str, handler: Callable[[Any], Any]) -> None:
+        """Add event handler for specific event type.
+
+        Args:
+            event_type: Type of event to handle
+            handler: Function to process the event
+        """
+        self._handlers[event_type] = handler
+
+    def remove_handler(self, event_type: str) -> None:
+        """Remove handler for specific event type."""
+        self._handlers.pop(event_type, None)
+
+    def process_event(self, event: Any) -> Any | None:
+        """Process a single streaming event.
+
+        Args:
+            event: The event data to process
+
+        Returns:
+            Result of handler if one exists, None otherwise
+        """
+        event_type = self._get_event_type(event)
+        handler = self._handlers.get(event_type)
+        if handler:
+            return handler(event)
+        return None
+
+    def process_stream(self, stream: Any) -> list[Any]:
+        """Process entire streaming response.
+
+        Args:
+            stream: The stream to process
+
+        Returns:
+            List of all processed event results
+        """
+        results = []
+        for event in stream:
+            result = self.process_event(event)
+            if result is not None:
+                results.append(result)
+        return results
+
+    async def process_stream_async(self, stream: Any) -> list[Any]:
+        """Async version of process_stream."""
+        results = []
+        async for event in stream:
+            result = self.process_event(event)
+            if result is not None:
+                results.append(result)
+        return results
+
+    def _get_event_type(self, event: Any) -> str:
+        """Extract event type from event data."""
+        # Handle different event formats
+        if hasattr(event, 'event') and event.event:
+            return event.event
+        elif hasattr(event, 'type') and event.type:
+            return event.type
+        elif isinstance(event, dict):
+            return event.get('event') or event.get('type') or 'unknown'
+        else:
+            return 'unknown'
+
+
+class StreamCollector:
+    """Utility for collecting and aggregating streaming events."""
+
+    def __init__(self) -> None:
+        """Initialize stream collector."""
+        self._events: list[Any] = []
+        self._aggregated: dict[str, Any] = {}
+
+    def collect(self, event: Any) -> None:
+        """Collect a streaming event."""
+        self._events.append(event)
+        self._aggregate_event(event)
+
+    def get_events(self, event_type: str | None = None) -> list[Any]:
+        """Get collected events, optionally filtered by type."""
+        if event_type is None:
+            return self._events.copy()
+
+        return [e for e in self._events if self._get_event_type(e) == event_type]
+
+    def get_aggregated(self) -> dict[str, Any]:
+        """Get aggregated event data."""
+        return self._aggregated.copy()
+
+    def clear(self) -> None:
+        """Clear all collected events and aggregated data."""
+        self._events.clear()
+        self._aggregated.clear()
+
+    def count_events(self, event_type: str | None = None) -> int:
+        """Count events, optionally filtered by type."""
+        if event_type is None:
+            return len(self._events)
+        return len(self.get_events(event_type))
+
+    def _aggregate_event(self, event: Any) -> None:
+        """Aggregate event data for summary statistics."""
+        event_type = self._get_event_type(event)
+
+        if event_type not in self._aggregated:
+            self._aggregated[event_type] = {
+                'count': 0,
+                'events': [],
+                'last_event': None
+            }
+
+        self._aggregated[event_type]['count'] += 1
+        self._aggregated[event_type]['events'].append(event)
+        self._aggregated[event_type]['last_event'] = event
+
+    def _get_event_type(self, event: Any) -> str:
+        """Extract event type from event data."""
+        if hasattr(event, 'event') and event.event:
+            return event.event
+        elif hasattr(event, 'type') and event.type:
+            return event.type
+        elif isinstance(event, dict):
+            return event.get('event') or event.get('type') or 'unknown'
+        else:
+            return 'unknown'
+
+
 # API Key Validation Functions
 def validate_api_key(api_key: str | None) -> bool:
     """Validate an API key format.

tests/test_streaming_support.py
Lines changed: 162 additions & 0 deletions

@@ -0,0 +1,162 @@
+"""Tests for streaming support functionality."""
+
+import pytest
+from gradient._utils import StreamProcessor, StreamCollector
+
+
+class TestStreamProcessor:
+    """Test stream processor functionality."""
+
+    def test_stream_processor_basic(self):
+        """Test basic stream processor functionality."""
+        processor = StreamProcessor()
+
+        # Add handler
+        results = []
+        def text_handler(event):
+            results.append(f"processed: {event.get('text', '')}")
+            return f"processed: {event.get('text', '')}"
+
+        processor.add_handler("text", text_handler)
+
+        # Process events
+        event1 = {"type": "text", "text": "Hello"}
+        event2 = {"type": "other", "data": "ignored"}
+        event3 = {"type": "text", "text": "World"}
+
+        result1 = processor.process_event(event1)
+        result2 = processor.process_event(event2)
+        result3 = processor.process_event(event3)
+
+        assert result1 == "processed: Hello"
+        assert result2 is None  # No handler for "other"
+        assert result3 == "processed: World"
+        assert results == ["processed: Hello", "processed: World"]
+
+    def test_stream_processor_remove_handler(self):
+        """Test removing event handlers."""
+        processor = StreamProcessor()
+
+        def handler(event):
+            return "handled"
+
+        processor.add_handler("test", handler)
+        assert processor.process_event({"type": "test"}) == "handled"
+
+        processor.remove_handler("test")
+        assert processor.process_event({"type": "test"}) is None
+
+    def test_stream_processor_process_stream(self):
+        """Test processing entire stream."""
+        processor = StreamProcessor()
+
+        def text_handler(event):
+            return event.get("text", "").upper()
+
+        processor.add_handler("text", text_handler)
+
+        stream = [
+            {"type": "text", "text": "hello"},
+            {"type": "other", "data": "ignored"},
+            {"type": "text", "text": "world"}
+        ]
+
+        results = processor.process_stream(stream)
+        assert results == ["HELLO", "WORLD"]
+
+    def test_stream_processor_event_type_extraction(self):
+        """Test event type extraction from different formats."""
+        processor = StreamProcessor()
+
+        # Test different event formats
+        event1 = {"type": "custom"}
+        event2 = type('MockEvent', (), {'event': 'mock'})()
+        event3 = {"event": "dict_event"}
+        event4 = "unknown_format"
+
+        assert processor._get_event_type(event1) == "custom"
+        assert processor._get_event_type(event2) == "mock"
+        assert processor._get_event_type(event3) == "dict_event"
+        assert processor._get_event_type(event4) == "unknown"
+
+
+class TestStreamCollector:
+    """Test stream collector functionality."""
+
+    def test_stream_collector_basic(self):
+        """Test basic stream collector functionality."""
+        collector = StreamCollector()
+
+        # Collect events
+        event1 = {"type": "text", "text": "Hello"}
+        event2 = {"type": "text", "text": "World"}
+        event3 = {"type": "error", "message": "Something went wrong"}
+
+        collector.collect(event1)
+        collector.collect(event2)
+        collector.collect(event3)
+
+        # Check all events
+        all_events = collector.get_events()
+        assert len(all_events) == 3
+
+        # Check filtered events
+        text_events = collector.get_events("text")
+        assert len(text_events) == 2
+        assert all(e["type"] == "text" for e in text_events)
+
+        error_events = collector.get_events("error")
+        assert len(error_events) == 1
+        assert error_events[0]["type"] == "error"
+
+    def test_stream_collector_aggregation(self):
+        """Test event aggregation."""
+        collector = StreamCollector()
+
+        # Collect events
+        collector.collect({"type": "text", "text": "Hello"})
+        collector.collect({"type": "text", "text": "World"})
+        collector.collect({"type": "error", "message": "Error 1"})
+        collector.collect({"type": "error", "message": "Error 2"})
+        collector.collect({"type": "text", "text": "Again"})
+
+        aggregated = collector.get_aggregated()
+
+        # Check text events aggregation
+        assert aggregated["text"]["count"] == 3
+        assert len(aggregated["text"]["events"]) == 3
+        assert aggregated["text"]["last_event"]["text"] == "Again"
+
+        # Check error events aggregation
+        assert aggregated["error"]["count"] == 2
+        assert len(aggregated["error"]["events"]) == 2
+        assert aggregated["error"]["last_event"]["message"] == "Error 2"
+
+    def test_stream_collector_count_events(self):
+        """Test event counting."""
+        collector = StreamCollector()
+
+        collector.collect({"type": "text"})
+        collector.collect({"type": "text"})
+        collector.collect({"type": "error"})
+        collector.collect({"type": "text"})
+
+        assert collector.count_events() == 4
+        assert collector.count_events("text") == 3
+        assert collector.count_events("error") == 1
+        assert collector.count_events("unknown") == 0
+
+    def test_stream_collector_clear(self):
+        """Test clearing collected events."""
+        collector = StreamCollector()
+
+        collector.collect({"type": "text"})
+        collector.collect({"type": "error"})
+
+        assert collector.count_events() == 2
+        assert len(collector.get_aggregated()) == 2
+
+        collector.clear()
+
+        assert collector.count_events() == 0
+        assert len(collector.get_aggregated()) == 0
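One gap worth noting: none of these tests cover process_stream_async. A sketch of such a test, reusing the StreamProcessor import at the top of the file and using asyncio.run so it does not assume a pytest-asyncio setup; the test name and async generator are hypothetical, not part of the commit.

    import asyncio

    def test_stream_processor_process_stream_async():
        """Sketch: the async variant should mirror process_stream."""
        processor = StreamProcessor()
        processor.add_handler("text", lambda event: event.get("text", "").upper())

        async def stream():
            # Async counterpart of the list-based stream used above.
            for event in [
                {"type": "text", "text": "hello"},
                {"type": "other", "data": "ignored"},
                {"type": "text", "text": "world"},
            ]:
                yield event

        results = asyncio.run(processor.process_stream_async(stream()))
        assert results == ["HELLO", "WORLD"]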
