11from contextvars import ContextVar , copy_context
2+ import sys
23from typing import (
34 Any ,
45 AsyncIterator ,
2829)
2930
3031
32+ if sys .version_info .minor < 10 :
33+ from guardrails .utils .polyfills import anext
34+
35+
3136class AsyncStreamRunner (AsyncRunner , StreamRunner ):
3237 # @async_trace_stream(name="/reasks", origin="AsyncStreamRunner.async_run")
3338 async def async_run (
@@ -137,96 +142,124 @@ async def async_step(
137142
138143 if self .output_type == OutputTypes .STRING :
139144 validator_service = AsyncValidatorService (self .disable_tracer )
140- async for chunk in stream_output :
141- chunk_text = self .get_chunk_text (chunk , api )
142- _ = self .is_last_chunk (chunk , api )
143145
144- fragment += chunk_text
146+ next_exists = True
147+ while next_exists :
148+ try :
149+ chunk = await anext (stream_output )
150+ chunk_text = self .get_chunk_text (chunk , api )
151+ _ = self .is_last_chunk (chunk , api )
145152
146- results = await validator_service .async_partial_validate (
147- chunk_text ,
148- self .metadata ,
149- self .validation_map ,
150- iteration ,
151- "$" ,
152- "$" ,
153- True ,
154- context = context ,
155- context_vars = stream_context_vars ,
156- )
157- validators = self .validation_map .get ("$" , [])
153+ fragment += chunk_text
158154
159- # collect the result validated_chunk into validation progress
160- # per validator
161- for result in results :
162- validator_log = result . validator_logs # type: ignore
163- validator = next (
164- filter (
165- lambda x : x . rail_alias == validator_log . registered_name ,
166- validators ,
167- ) ,
168- None ,
155+ results = await validator_service . async_partial_validate (
156+ chunk_text ,
157+ self . metadata ,
158+ self . validation_map ,
159+ iteration ,
160+ "$" ,
161+ "$" ,
162+ True ,
163+ context = context ,
164+ context_vars = stream_context_vars ,
169165 )
170- if (
171- validator_log .validation_result
172- and validator_log .validation_result .validated_chunk
173- ):
174- is_filter = validator .on_fail_descriptor is OnFailAction .FILTER # type: ignore
175- is_refrain = (
176- validator .on_fail_descriptor is OnFailAction .REFRAIN # type: ignore
166+ validators = self .validation_map .get ("$" , [])
167+
168+ # collect the result validated_chunk into validation progress
169+ # per validator
170+ for result in results :
171+ validator_log = result .validator_logs # type: ignore
172+ validator = next (
173+ filter (
174+ lambda x : x .rail_alias == validator_log .registered_name ,
175+ validators ,
176+ ),
177+ None ,
177178 )
178- if validator_log .validation_result .outcome == "fail" :
179- validation_passed = False
180- reasks , valid_op = self .introspect (
179+ if (
181180 validator_log .validation_result
182- )
183- if reasks :
184- raise ValueError (
185- "Reasks are not yet supported with streaming. Please "
186- "remove reasks from schema or disable streaming."
181+ and validator_log .validation_result .validated_chunk
182+ ):
183+ is_filter = (
184+ validator .on_fail_descriptor is OnFailAction .FILTER # type: ignore
185+ )
186+ is_refrain = (
187+ validator .on_fail_descriptor is OnFailAction .REFRAIN # type: ignore
188+ )
189+ if validator_log .validation_result .outcome == "fail" :
190+ validation_passed = False
191+ reasks , valid_op = self .introspect (
192+ validator_log .validation_result
187193 )
194+ if reasks :
195+ raise ValueError (
196+ "Reasks are not yet supported with streaming. "
197+ "Please remove reasks from schema or disable"
198+ " streaming."
199+ )
188200
189- if isinstance (validator_log .validation_result , PassResult ):
190- chunk = validator_log .validation_result .validated_chunk
191- elif isinstance (validator_log .validation_result , FailResult ):
192- if is_filter or is_refrain :
193- refrain_triggered = True
194- chunk = ""
195- else :
196- chunk = validator_service .perform_correction (
197- validator_log .validation_result ,
198- validator_log .validation_result .validated_chunk ,
199- validator , # type: ignore
200- rechecked_value = None ,
201- ) # type: ignore
201+ if isinstance (validator_log .validation_result , PassResult ):
202+ chunk = validator_log .validation_result .validated_chunk
203+ elif isinstance (
204+ validator_log .validation_result , FailResult
205+ ):
206+ if is_filter or is_refrain :
207+ refrain_triggered = True
208+ chunk = ""
209+ else :
210+ chunk = validator_service .perform_correction (
211+ validator_log .validation_result ,
212+ validator_log .validation_result .validated_chunk ,
213+ validator , # type: ignore
214+ rechecked_value = None ,
215+ ) # type: ignore
202216
203- if validator_log .validator_name not in validation_progress :
204- validation_progress [validator_log .validator_name ] = ""
217+ if validator_log .validator_name not in validation_progress :
218+ validation_progress [validator_log .validator_name ] = ""
205219
206- validation_progress [validator_log .validator_name ] += chunk
207- # if there is an entry for every validator
208- # run a merge and emit a validation outcome
209- if len (validation_progress ) == len (validators ) or len (validators ) == 0 :
210- if refrain_triggered :
211- current = ""
212- else :
213- merge_chunks = []
214- for piece in validation_progress :
215- merge_chunks .append (validation_progress [piece ])
220+ validation_progress [validator_log .validator_name ] += chunk
221+ # if there is an entry for every validator
222+ # run a merge and emit a validation outcome
223+ if (
224+ len (validation_progress ) == len (validators )
225+ or len (validators ) == 0
226+ ):
227+ if refrain_triggered :
228+ current = ""
229+ else :
230+ merge_chunks = []
231+ for piece in validation_progress :
232+ merge_chunks .append (validation_progress [piece ])
216233
217- current = validator_service .multi_merge (fragment , merge_chunks )
234+ current = validator_service .multi_merge (
235+ fragment , merge_chunks
236+ )
218237
219- vo = ValidationOutcome (
220- call_id = call_log .id , # type: ignore
221- raw_llm_output = fragment ,
222- validated_output = current ,
223- validation_passed = True ,
224- )
225- fragment = ""
226- validation_progress = {}
227- refrain_triggered = False
238+ vo = ValidationOutcome (
239+ call_id = call_log .id , # type: ignore
240+ raw_llm_output = fragment ,
241+ validated_output = current ,
242+ validation_passed = True ,
243+ )
244+ fragment = ""
245+ validation_progress = {}
246+ refrain_triggered = False
247+
248+ yield vo
228249
229- yield vo
250+ except StopIteration :
251+ next_exists = False
252+ except StopAsyncIteration :
253+ next_exists = False
254+ except Exception as e :
255+ raise e
256+ finally :
257+ # reset all context vars
258+ for context_var in context_vars .values ():
259+ token = context .run (context_var .set , [])
260+ context .run (context_var .reset , token )
261+ token = context .run (stream_context_vars .set , {})
262+ context .run (stream_context_vars .reset , token )
230263
231264 # if theres anything left merge and emit a chunk
232265 if len (validation_progress ) > 0 :
@@ -242,48 +275,64 @@ async def async_step(
242275 validation_passed = validation_passed ,
243276 )
244277 else :
245- async for chunk in stream_output :
246- chunk_text = self .get_chunk_text (chunk , api )
247- fragment += chunk_text
278+ next_exists = True
279+ while next_exists :
280+ try :
281+ chunk = await anext (stream_output )
282+ chunk_text = self .get_chunk_text (chunk , api )
283+ fragment += chunk_text
248284
249- parsed_fragment , move_to_next = self .parse (
250- fragment , output_schema , verified = verified
251- )
252- if move_to_next :
253- continue
254- validated_fragment = await self .async_validate (
255- iteration ,
256- index ,
257- parsed_fragment ,
258- output_schema ,
259- validate_subschema = True ,
260- context = context ,
261- context_vars = stream_context_vars ,
262- )
263- if isinstance (validated_fragment , SkeletonReAsk ):
264- raise ValueError (
265- "Received fragment schema is an invalid sub-schema "
266- "of the expected output JSON schema."
285+ parsed_fragment , move_to_next = self .parse (
286+ fragment , output_schema , verified = verified
267287 )
268-
269- reasks , valid_op = self .introspect (validated_fragment )
270- if reasks :
271- raise ValueError (
272- "Reasks are not yet supported with streaming. Please "
273- "remove reasks from schema or disable streaming."
288+ if move_to_next :
289+ continue
290+ validated_fragment = await self .async_validate (
291+ iteration ,
292+ index ,
293+ parsed_fragment ,
294+ output_schema ,
295+ validate_subschema = True ,
296+ context = context ,
297+ context_vars = stream_context_vars ,
274298 )
299+ if isinstance (validated_fragment , SkeletonReAsk ):
300+ raise ValueError (
301+ "Received fragment schema is an invalid sub-schema "
302+ "of the expected output JSON schema."
303+ )
275304
276- if self .output_type == OutputTypes .LIST :
277- validation_response = cast (list , validated_fragment )
278- else :
279- validation_response = cast (dict , validated_fragment )
280- yield ValidationOutcome (
281- call_id = call_log .id , # type: ignore
282- raw_llm_output = validated_fragment ,
283- validated_output = chunk_text ,
284- validation_passed = validated_fragment is not None ,
285- )
286- fragment = ""
305+ reasks , valid_op = self .introspect (validated_fragment )
306+ if reasks :
307+ raise ValueError (
308+ "Reasks are not yet supported with streaming. Please "
309+ "remove reasks from schema or disable streaming."
310+ )
311+
312+ if self .output_type == OutputTypes .LIST :
313+ validation_response = cast (list , validated_fragment )
314+ else :
315+ validation_response = cast (dict , validated_fragment )
316+ yield ValidationOutcome (
317+ call_id = call_log .id , # type: ignore
318+ raw_llm_output = fragment ,
319+ validated_output = validated_fragment ,
320+ validation_passed = validated_fragment is not None ,
321+ )
322+ fragment = ""
323+ except StopIteration :
324+ next_exists = False
325+ except StopAsyncIteration :
326+ next_exists = False
327+ except Exception as e :
328+ raise e
329+ finally :
330+ # reset all context vars
331+ for context_var in context_vars .values ():
332+ token = context .run (context_var .set , [])
333+ context .run (context_var .reset , token )
334+ token = context .run (stream_context_vars .set , {})
335+ context .run (stream_context_vars .reset , token )
287336
288337 iteration .outputs .raw_output = fragment
289338 # FIXME: Handle case where parsing continuously fails/is a reask
0 commit comments