1+ """Tests for streaming support functionality."""
2+
3+ import pytest
4+ from gradient ._utils import StreamProcessor , StreamCollector
5+
6+
7+ class TestStreamProcessor :
8+ """Test stream processor functionality."""
9+
10+ def test_stream_processor_basic (self ):
11+ """Test basic stream processor functionality."""
12+ processor = StreamProcessor ()
13+
14+ # Add handler
15+ results = []
16+ def text_handler (event ):
17+ results .append (f"processed: { event .get ('text' , '' )} " )
18+ return f"processed: { event .get ('text' , '' )} "
19+
20+ processor .add_handler ("text" , text_handler )
21+
22+ # Process events
23+ event1 = {"type" : "text" , "text" : "Hello" }
24+ event2 = {"type" : "other" , "data" : "ignored" }
25+ event3 = {"type" : "text" , "text" : "World" }
26+
27+ result1 = processor .process_event (event1 )
28+ result2 = processor .process_event (event2 )
29+ result3 = processor .process_event (event3 )
30+
31+ assert result1 == "processed: Hello"
32+ assert result2 is None # No handler for "other"
33+ assert result3 == "processed: World"
34+ assert results == ["processed: Hello" , "processed: World" ]
35+
36+ def test_stream_processor_remove_handler (self ):
37+ """Test removing event handlers."""
38+ processor = StreamProcessor ()
39+
40+ def handler (event ):
41+ return "handled"
42+
43+ processor .add_handler ("test" , handler )
44+ assert processor .process_event ({"type" : "test" }) == "handled"
45+
46+ processor .remove_handler ("test" )
47+ assert processor .process_event ({"type" : "test" }) is None
48+
49+ def test_stream_processor_process_stream (self ):
50+ """Test processing entire stream."""
51+ processor = StreamProcessor ()
52+
53+ def text_handler (event ):
54+ return event .get ("text" , "" ).upper ()
55+
56+ processor .add_handler ("text" , text_handler )
57+
58+ stream = [
59+ {"type" : "text" , "text" : "hello" },
60+ {"type" : "other" , "data" : "ignored" },
61+ {"type" : "text" , "text" : "world" }
62+ ]
63+
64+ results = processor .process_stream (stream )
65+ assert results == ["HELLO" , "WORLD" ]
66+
67+ def test_stream_processor_event_type_extraction (self ):
68+ """Test event type extraction from different formats."""
69+ processor = StreamProcessor ()
70+
71+ # Test different event formats
72+ event1 = {"type" : "custom" }
73+ event2 = type ('MockEvent' , (), {'event' : 'mock' })()
74+ event3 = {"event" : "dict_event" }
75+ event4 = "unknown_format"
76+
77+ assert processor ._get_event_type (event1 ) == "custom"
78+ assert processor ._get_event_type (event2 ) == "mock"
79+ assert processor ._get_event_type (event3 ) == "dict_event"
80+ assert processor ._get_event_type (event4 ) == "unknown"
81+
82+
83+ class TestStreamCollector :
84+ """Test stream collector functionality."""
85+
86+ def test_stream_collector_basic (self ):
87+ """Test basic stream collector functionality."""
88+ collector = StreamCollector ()
89+
90+ # Collect events
91+ event1 = {"type" : "text" , "text" : "Hello" }
92+ event2 = {"type" : "text" , "text" : "World" }
93+ event3 = {"type" : "error" , "message" : "Something went wrong" }
94+
95+ collector .collect (event1 )
96+ collector .collect (event2 )
97+ collector .collect (event3 )
98+
99+ # Check all events
100+ all_events = collector .get_events ()
101+ assert len (all_events ) == 3
102+
103+ # Check filtered events
104+ text_events = collector .get_events ("text" )
105+ assert len (text_events ) == 2
106+ assert all (e ["type" ] == "text" for e in text_events )
107+
108+ error_events = collector .get_events ("error" )
109+ assert len (error_events ) == 1
110+ assert error_events [0 ]["type" ] == "error"
111+
112+ def test_stream_collector_aggregation (self ):
113+ """Test event aggregation."""
114+ collector = StreamCollector ()
115+
116+ # Collect events
117+ collector .collect ({"type" : "text" , "text" : "Hello" })
118+ collector .collect ({"type" : "text" , "text" : "World" })
119+ collector .collect ({"type" : "error" , "message" : "Error 1" })
120+ collector .collect ({"type" : "error" , "message" : "Error 2" })
121+ collector .collect ({"type" : "text" , "text" : "Again" })
122+
123+ aggregated = collector .get_aggregated ()
124+
125+ # Check text events aggregation
126+ assert aggregated ["text" ]["count" ] == 3
127+ assert len (aggregated ["text" ]["events" ]) == 3
128+ assert aggregated ["text" ]["last_event" ]["text" ] == "Again"
129+
130+ # Check error events aggregation
131+ assert aggregated ["error" ]["count" ] == 2
132+ assert len (aggregated ["error" ]["events" ]) == 2
133+ assert aggregated ["error" ]["last_event" ]["message" ] == "Error 2"
134+
135+ def test_stream_collector_count_events (self ):
136+ """Test event counting."""
137+ collector = StreamCollector ()
138+
139+ collector .collect ({"type" : "text" })
140+ collector .collect ({"type" : "text" })
141+ collector .collect ({"type" : "error" })
142+ collector .collect ({"type" : "text" })
143+
144+ assert collector .count_events () == 4
145+ assert collector .count_events ("text" ) == 3
146+ assert collector .count_events ("error" ) == 1
147+ assert collector .count_events ("unknown" ) == 0
148+
149+ def test_stream_collector_clear (self ):
150+ """Test clearing collected events."""
151+ collector = StreamCollector ()
152+
153+ collector .collect ({"type" : "text" })
154+ collector .collect ({"type" : "error" })
155+
156+ assert collector .count_events () == 2
157+ assert len (collector .get_aggregated ()) == 2
158+
159+ collector .clear ()
160+
161+ assert collector .count_events () == 0
162+ assert len (collector .get_aggregated ()) == 0
0 commit comments