Skip to content

Commit aefd640

Browse files
committed
Reset context vars at the end of the stream
1 parent 5c00824 commit aefd640

File tree

1 file changed

+164
-115
lines changed

1 file changed

+164
-115
lines changed

guardrails/run/async_stream_runner.py

Lines changed: 164 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from contextvars import ContextVar, copy_context
2+
import sys
23
from typing import (
34
Any,
45
AsyncIterator,
@@ -28,6 +29,10 @@
2829
)
2930

3031

32+
if sys.version_info.minor < 10:
33+
from guardrails.utils.polyfills import anext
34+
35+
3136
class AsyncStreamRunner(AsyncRunner, StreamRunner):
3237
# @async_trace_stream(name="/reasks", origin="AsyncStreamRunner.async_run")
3338
async def async_run(
@@ -137,96 +142,124 @@ async def async_step(
137142

138143
if self.output_type == OutputTypes.STRING:
139144
validator_service = AsyncValidatorService(self.disable_tracer)
140-
async for chunk in stream_output:
141-
chunk_text = self.get_chunk_text(chunk, api)
142-
_ = self.is_last_chunk(chunk, api)
143145

144-
fragment += chunk_text
146+
next_exists = True
147+
while next_exists:
148+
try:
149+
chunk = await anext(stream_output)
150+
chunk_text = self.get_chunk_text(chunk, api)
151+
_ = self.is_last_chunk(chunk, api)
145152

146-
results = await validator_service.async_partial_validate(
147-
chunk_text,
148-
self.metadata,
149-
self.validation_map,
150-
iteration,
151-
"$",
152-
"$",
153-
True,
154-
context=context,
155-
context_vars=stream_context_vars,
156-
)
157-
validators = self.validation_map.get("$", [])
153+
fragment += chunk_text
158154

159-
# collect the result validated_chunk into validation progress
160-
# per validator
161-
for result in results:
162-
validator_log = result.validator_logs # type: ignore
163-
validator = next(
164-
filter(
165-
lambda x: x.rail_alias == validator_log.registered_name,
166-
validators,
167-
),
168-
None,
155+
results = await validator_service.async_partial_validate(
156+
chunk_text,
157+
self.metadata,
158+
self.validation_map,
159+
iteration,
160+
"$",
161+
"$",
162+
True,
163+
context=context,
164+
context_vars=stream_context_vars,
169165
)
170-
if (
171-
validator_log.validation_result
172-
and validator_log.validation_result.validated_chunk
173-
):
174-
is_filter = validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore
175-
is_refrain = (
176-
validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore
166+
validators = self.validation_map.get("$", [])
167+
168+
# collect the result validated_chunk into validation progress
169+
# per validator
170+
for result in results:
171+
validator_log = result.validator_logs # type: ignore
172+
validator = next(
173+
filter(
174+
lambda x: x.rail_alias == validator_log.registered_name,
175+
validators,
176+
),
177+
None,
177178
)
178-
if validator_log.validation_result.outcome == "fail":
179-
validation_passed = False
180-
reasks, valid_op = self.introspect(
179+
if (
181180
validator_log.validation_result
182-
)
183-
if reasks:
184-
raise ValueError(
185-
"Reasks are not yet supported with streaming. Please "
186-
"remove reasks from schema or disable streaming."
181+
and validator_log.validation_result.validated_chunk
182+
):
183+
is_filter = (
184+
validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore
185+
)
186+
is_refrain = (
187+
validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore
188+
)
189+
if validator_log.validation_result.outcome == "fail":
190+
validation_passed = False
191+
reasks, valid_op = self.introspect(
192+
validator_log.validation_result
187193
)
194+
if reasks:
195+
raise ValueError(
196+
"Reasks are not yet supported with streaming. "
197+
"Please remove reasks from schema or disable"
198+
" streaming."
199+
)
188200

189-
if isinstance(validator_log.validation_result, PassResult):
190-
chunk = validator_log.validation_result.validated_chunk
191-
elif isinstance(validator_log.validation_result, FailResult):
192-
if is_filter or is_refrain:
193-
refrain_triggered = True
194-
chunk = ""
195-
else:
196-
chunk = validator_service.perform_correction(
197-
validator_log.validation_result,
198-
validator_log.validation_result.validated_chunk,
199-
validator, # type: ignore
200-
rechecked_value=None,
201-
) # type: ignore
201+
if isinstance(validator_log.validation_result, PassResult):
202+
chunk = validator_log.validation_result.validated_chunk
203+
elif isinstance(
204+
validator_log.validation_result, FailResult
205+
):
206+
if is_filter or is_refrain:
207+
refrain_triggered = True
208+
chunk = ""
209+
else:
210+
chunk = validator_service.perform_correction(
211+
validator_log.validation_result,
212+
validator_log.validation_result.validated_chunk,
213+
validator, # type: ignore
214+
rechecked_value=None,
215+
) # type: ignore
202216

203-
if validator_log.validator_name not in validation_progress:
204-
validation_progress[validator_log.validator_name] = ""
217+
if validator_log.validator_name not in validation_progress:
218+
validation_progress[validator_log.validator_name] = ""
205219

206-
validation_progress[validator_log.validator_name] += chunk
207-
# if there is an entry for every validator
208-
# run a merge and emit a validation outcome
209-
if len(validation_progress) == len(validators) or len(validators) == 0:
210-
if refrain_triggered:
211-
current = ""
212-
else:
213-
merge_chunks = []
214-
for piece in validation_progress:
215-
merge_chunks.append(validation_progress[piece])
220+
validation_progress[validator_log.validator_name] += chunk
221+
# if there is an entry for every validator
222+
# run a merge and emit a validation outcome
223+
if (
224+
len(validation_progress) == len(validators)
225+
or len(validators) == 0
226+
):
227+
if refrain_triggered:
228+
current = ""
229+
else:
230+
merge_chunks = []
231+
for piece in validation_progress:
232+
merge_chunks.append(validation_progress[piece])
216233

217-
current = validator_service.multi_merge(fragment, merge_chunks)
234+
current = validator_service.multi_merge(
235+
fragment, merge_chunks
236+
)
218237

219-
vo = ValidationOutcome(
220-
call_id=call_log.id, # type: ignore
221-
raw_llm_output=fragment,
222-
validated_output=current,
223-
validation_passed=True,
224-
)
225-
fragment = ""
226-
validation_progress = {}
227-
refrain_triggered = False
238+
vo = ValidationOutcome(
239+
call_id=call_log.id, # type: ignore
240+
raw_llm_output=fragment,
241+
validated_output=current,
242+
validation_passed=True,
243+
)
244+
fragment = ""
245+
validation_progress = {}
246+
refrain_triggered = False
247+
248+
yield vo
228249

229-
yield vo
250+
except StopIteration:
251+
next_exists = False
252+
except StopAsyncIteration:
253+
next_exists = False
254+
except Exception as e:
255+
raise e
256+
finally:
257+
# reset all context vars
258+
for context_var in context_vars.values():
259+
token = context.run(context_var.set, [])
260+
context.run(context_var.reset, token)
261+
token = context.run(stream_context_vars.set, {})
262+
context.run(stream_context_vars.reset, token)
230263

231264
# if theres anything left merge and emit a chunk
232265
if len(validation_progress) > 0:
@@ -242,48 +275,64 @@ async def async_step(
242275
validation_passed=validation_passed,
243276
)
244277
else:
245-
async for chunk in stream_output:
246-
chunk_text = self.get_chunk_text(chunk, api)
247-
fragment += chunk_text
278+
next_exists = True
279+
while next_exists:
280+
try:
281+
chunk = await anext(stream_output)
282+
chunk_text = self.get_chunk_text(chunk, api)
283+
fragment += chunk_text
248284

249-
parsed_fragment, move_to_next = self.parse(
250-
fragment, output_schema, verified=verified
251-
)
252-
if move_to_next:
253-
continue
254-
validated_fragment = await self.async_validate(
255-
iteration,
256-
index,
257-
parsed_fragment,
258-
output_schema,
259-
validate_subschema=True,
260-
context=context,
261-
context_vars=stream_context_vars,
262-
)
263-
if isinstance(validated_fragment, SkeletonReAsk):
264-
raise ValueError(
265-
"Received fragment schema is an invalid sub-schema "
266-
"of the expected output JSON schema."
285+
parsed_fragment, move_to_next = self.parse(
286+
fragment, output_schema, verified=verified
267287
)
268-
269-
reasks, valid_op = self.introspect(validated_fragment)
270-
if reasks:
271-
raise ValueError(
272-
"Reasks are not yet supported with streaming. Please "
273-
"remove reasks from schema or disable streaming."
288+
if move_to_next:
289+
continue
290+
validated_fragment = await self.async_validate(
291+
iteration,
292+
index,
293+
parsed_fragment,
294+
output_schema,
295+
validate_subschema=True,
296+
context=context,
297+
context_vars=stream_context_vars,
274298
)
299+
if isinstance(validated_fragment, SkeletonReAsk):
300+
raise ValueError(
301+
"Received fragment schema is an invalid sub-schema "
302+
"of the expected output JSON schema."
303+
)
275304

276-
if self.output_type == OutputTypes.LIST:
277-
validation_response = cast(list, validated_fragment)
278-
else:
279-
validation_response = cast(dict, validated_fragment)
280-
yield ValidationOutcome(
281-
call_id=call_log.id, # type: ignore
282-
raw_llm_output=validated_fragment,
283-
validated_output=chunk_text,
284-
validation_passed=validated_fragment is not None,
285-
)
286-
fragment = ""
305+
reasks, valid_op = self.introspect(validated_fragment)
306+
if reasks:
307+
raise ValueError(
308+
"Reasks are not yet supported with streaming. Please "
309+
"remove reasks from schema or disable streaming."
310+
)
311+
312+
if self.output_type == OutputTypes.LIST:
313+
validation_response = cast(list, validated_fragment)
314+
else:
315+
validation_response = cast(dict, validated_fragment)
316+
yield ValidationOutcome(
317+
call_id=call_log.id, # type: ignore
318+
raw_llm_output=fragment,
319+
validated_output=validated_fragment,
320+
validation_passed=validated_fragment is not None,
321+
)
322+
fragment = ""
323+
except StopIteration:
324+
next_exists = False
325+
except StopAsyncIteration:
326+
next_exists = False
327+
except Exception as e:
328+
raise e
329+
finally:
330+
# reset all context vars
331+
for context_var in context_vars.values():
332+
token = context.run(context_var.set, [])
333+
context.run(context_var.reset, token)
334+
token = context.run(stream_context_vars.set, {})
335+
context.run(stream_context_vars.reset, token)
287336

288337
iteration.outputs.raw_output = fragment
289338
# FIXME: Handle case where parsing continuously fails/is a reask

0 commit comments

Comments (0)