Commit 16b5b27

Merge branch 'main' into notebook-fixes
2 parents 853d9b1 + 5528c2c commit 16b5b27

File tree

9 files changed: +2227 −52 lines changed

docs/examples/streaming.ipynb

Lines changed: 1230 additions & 0 deletions
Large diffs are not rendered by default.

guardrails/llm_providers.py

Lines changed: 31 additions & 5 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
+from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast
 
 from pydantic import BaseModel
 
@@ -166,6 +166,7 @@ def _invoke_llm(
             api_key = kwargs.pop("api_key")
         else:
             api_key = None
+
         client = OpenAIClient(api_key=api_key)
         return client.create_chat_completion(
             model=model,
@@ -256,9 +257,34 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse:
             )
         ```
         """
-        return LLMResponse(
-            output=self.llm_api(*args, **kwargs),
-        )
+        # Get the response from the callable.
+        # The LLM response should either be a
+        # string or a generator object of strings.
+        llm_response = self.llm_api(*args, **kwargs)
+
+        # Check if the stream kwarg was passed in.
+        if kwargs.get("stream", None) in [None, False]:
+            # If stream is not defined or is set to False,
+            # return the default behavior.
+            # Strongly type the response as a string.
+            llm_response = cast(str, llm_response)
+            return LLMResponse(
+                output=llm_response,
+            )
+        else:
+            # If stream is defined and set to True,
+            # the callable returns a generator object.
+            complete_output = ""
+
+            # Strongly type the response as an iterable of strings.
+            llm_response = cast(Iterable[str], llm_response)
+            for response in llm_response:
+                complete_output += response
+
+            # Return the LLMResponse.
+            return LLMResponse(
+                output=complete_output,
+            )
 
 
 def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
@@ -405,6 +431,7 @@ async def invoke_llm(
             api_key = kwargs.pop("api_key")
         else:
             api_key = None
+
         aclient = AsyncOpenAIClient(api_key=api_key)
         return await aclient.create_chat_completion(
             model=model,
@@ -481,7 +508,6 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:
 def get_async_llm_ask(
     llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
 ) -> AsyncPromptCallableBase:
-
     # these only work with openai v0 (None otherwise)
     if llm_api == get_static_openai_acreate_func():
         return AsyncOpenAICallable(*args, **kwargs)
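
Taken together, the llm_providers.py changes let an arbitrary user-supplied LLM callable stream its output: when stream=True is passed, the callable is expected to return a generator of string chunks, which _invoke_llm now concatenates into a single LLMResponse. Below is a minimal standalone sketch of that behavior; invoke_arbitrary_callable and fake_streaming_llm are illustrative stand-ins for this example, not part of the diff.

from typing import Callable, Iterable, Union, cast

def invoke_arbitrary_callable(llm_api: Callable, *args, **kwargs) -> str:
    # Mirrors the merged _invoke_llm logic: the callable may return either
    # a plain string or, when stream=True, a generator of string chunks.
    llm_response: Union[str, Iterable[str]] = llm_api(*args, **kwargs)

    if kwargs.get("stream", None) in [None, False]:
        # Non-streaming: the callable returned the full output as one string.
        return cast(str, llm_response)

    # Streaming: accumulate the chunks into one complete output string.
    complete_output = ""
    for chunk in cast(Iterable[str], llm_response):
        complete_output += chunk
    return complete_output

def fake_streaming_llm(prompt: str, stream: bool = False):
    # Hypothetical stand-in for a user-supplied LLM callable.
    chunks = ["Hello", ", ", "world", "!"]
    return (c for c in chunks) if stream else "".join(chunks)

print(invoke_arbitrary_callable(fake_streaming_llm, "Hi", stream=False))  # Hello, world!
print(invoke_arbitrary_callable(fake_streaming_llm, "Hi", stream=True))   # Hello, world!

Either way the caller receives the complete output; the streaming branch simply rebuilds it from the emitted chunks.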
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+from typing import Dict, List
+
+import tiktoken
+
+
+def num_tokens_from_string(text: str, model_name: str) -> int:
+    """Returns the number of tokens in a text string.
+
+    Supported for OpenAI models only. This is a helper function
+    that is required when OpenAI's `stream` parameter is set to `True`,
+    because OpenAI does not return the number of tokens in that case.
+    Requires the `tiktoken` package to be installed.
+
+    Args:
+        text (str): The text string to count the number of tokens in.
+        model_name (str): The name of the OpenAI model to use.
+
+    Returns:
+        num_tokens (int): The number of tokens in the text string.
+    """
+    encoding = tiktoken.encoding_for_model(model_name)
+    num_tokens = len(encoding.encode(text))
+    return num_tokens
+
+
+def num_tokens_from_messages(
+    messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0613"
+) -> int:
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = (
+            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        )
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        print(
+            """Warning: gpt-3.5-turbo may update over time.
+            Returning num tokens assuming gpt-3.5-turbo-0613."""
+        )
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        print(
+            """Warning: gpt-4 may update over time.
+            Returning num tokens assuming gpt-4-0613."""
+        )
+        return num_tokens_from_messages(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}.
+            See https://github.com/openai/openai-python/blob/main/chatml.md for
+            information on how messages are converted to tokens."""
+        )
+
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+
+    # every reply is primed with <|start|>assistant<|message|>
+    num_tokens += 3
+    return num_tokens
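
A short usage sketch of the two new helpers above, assuming they are in scope (the module path for this new file is not shown in the rendered diff) and that tiktoken is installed; the prompt and messages are illustrative. This mirrors the docstring's motivation: with stream=True, OpenAI does not report token usage, so counts are recomputed locally.

prompt = "Tell me something interesting about token counting."

# Count tokens in a bare string for a given OpenAI model.
print(num_tokens_from_string(prompt, model_name="gpt-3.5-turbo"))

# Count tokens for a chat-style message list, including per-message overhead.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
print(num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"))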
