42 changes: 39 additions & 3 deletions libs/partners/perplexity/langchain_perplexity/chat_models.py
@@ -28,7 +28,11 @@
     SystemMessageChunk,
     ToolMessageChunk,
 )
-from langchain_core.messages.ai import UsageMetadata, subtract_usage
+from langchain_core.messages.ai import (
+    OutputTokenDetails,
+    UsageMetadata,
+    subtract_usage,
+)
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -49,13 +53,28 @@ def _is_pydantic_class(obj: Any) -> bool:
 
 
 def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
+    """Create UsageMetadata from Perplexity token usage data.
+
+    Args:
+        token_usage: Dictionary containing token usage information from Perplexity API.
+
+    Returns:
+        UsageMetadata with properly structured token counts and details.
+    """
     input_tokens = token_usage.get("prompt_tokens", 0)
     output_tokens = token_usage.get("completion_tokens", 0)
     total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+
+    # Build output_token_details for Perplexity-specific fields
+    output_token_details: OutputTokenDetails = {}
+    output_token_details["reasoning"] = token_usage.get("reasoning_tokens", 0)
+    output_token_details["citation_tokens"] = token_usage.get("citation_tokens", 0)  # type: ignore[typeddict-unknown-key]
 
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
         total_tokens=total_tokens,
+        output_token_details=output_token_details,
     )
 
 
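For reference, a minimal sketch of what the new helper returns for a representative payload (field names follow the diff above; the sample values are made up). Note that both detail keys are always populated, defaulting to 0 when the API omits them:

```python
# Hypothetical Perplexity usage payload; values are illustrative only.
token_usage = {
    "prompt_tokens": 12,
    "completion_tokens": 48,
    "total_tokens": 60,
    "reasoning_tokens": 20,
    "citation_tokens": 5,
}

usage_metadata = _create_usage_metadata(token_usage)
# UsageMetadata is a TypedDict, so the result is a plain dict:
# {"input_tokens": 12, "output_tokens": 48, "total_tokens": 60,
#  "output_token_details": {"reasoning": 20, "citation_tokens": 5}}
```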
@@ -301,6 +320,7 @@ def _stream(
         prev_total_usage: UsageMetadata | None = None
 
         added_model_name: bool = False
+        added_search_queries: bool = False
         for chunk in stream_resp:
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
@@ -332,6 +352,13 @@
                     generation_info["model_name"] = model_name
                     added_model_name = True
 
+            # Add num_search_queries to generation_info if present
+            if total_usage := chunk.get("usage"):
+                if num_search_queries := total_usage.get("num_search_queries"):
+                    if not added_search_queries:
+                        generation_info["num_search_queries"] = num_search_queries
+                        added_search_queries = True
+
             chunk = self._convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
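Assuming langchain-core merges `generation_info` into each streamed chunk's `response_metadata` (as it does for other keys such as `model_name`), the count should surface once on the consumer side. A sketch, not verified against a live API; the `PPLX_API_KEY` environment variable and the `sonar` model are assumptions:

```python
from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="sonar")  # model choice is illustrative

for chunk in llm.stream("What is LangChain?"):
    # The added_search_queries flag means at most one chunk carries the count.
    if n := chunk.response_metadata.get("num_search_queries"):
        print(f"Perplexity issued {n} search queries")
```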
@@ -369,20 +396,29 @@ def _generate(
         params = {**params, **kwargs}
         response = self.client.chat.completions.create(messages=message_dicts, **params)
         if usage := getattr(response, "usage", None):
-            usage_metadata = _create_usage_metadata(usage.model_dump())
+            usage_dict = usage.model_dump()
+            usage_metadata = _create_usage_metadata(usage_dict)
         else:
             usage_metadata = None
+            usage_dict = {}
 
         additional_kwargs = {}
         for attr in ["citations", "images", "related_questions", "search_results"]:
             if hasattr(response, attr):
                 additional_kwargs[attr] = getattr(response, attr)
 
+        # Build response_metadata with model_name and num_search_queries
+        response_metadata: dict[str, Any] = {
+            "model_name": getattr(response, "model", self.model)
+        }
+        if num_search_queries := usage_dict.get("num_search_queries"):
+            response_metadata["num_search_queries"] = num_search_queries
+
         message = AIMessage(
             content=response.choices[0].message.content,
             additional_kwargs=additional_kwargs,
             usage_metadata=usage_metadata,
-            response_metadata={"model_name": getattr(response, "model", self.model)},
+            response_metadata=response_metadata,
         )
         return ChatResult(generations=[ChatGeneration(message=message)])
 
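The non-streaming path can then be exercised end to end; a hedged sketch assuming `PPLX_API_KEY` is set. Because of the truthiness check above, `num_search_queries` is omitted from `response_metadata` when the API reports it as 0 or leaves it out:

```python
from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="sonar")  # model choice is illustrative

msg = llm.invoke("Summarize the latest LangChain release")
print(msg.usage_metadata)      # now includes output_token_details
print(msg.response_metadata)   # e.g. {"model_name": "sonar", "num_search_queries": 2}
```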