"""
Python reproduction of OutputParserException issue with Ollama
Original issue from JavaScript: OutputParserException with empty text parsing

Run this file directly (``python test_output_parser_issue.py``). All calls that
reach the Ollama server live behind the ``__main__`` guard: the file name
matches pytest's ``test_*.py`` collection pattern, so importing it must not
trigger any model invocation.
"""

from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field

# Initialize Ollama (equivalent to JavaScript version)
llm = ChatOllama(
    model="llama3.2:3b",
    temperature=0,
    base_url="http://localhost:11434",
)


# Define the classification schema (equivalent to Zod schema)
class ClassificationSchema(BaseModel):
    """Extract sentiment, aggressiveness, and language from text"""

    sentiment: Literal["happy", "neutral", "sad"] = Field(
        description="The sentiment of the text"
    )
    aggressiveness: int = Field(
        description="Describes how aggressive the statement is on a scale from 1 to 5. The higher the number the more aggressive"
    )
    language: Literal["spanish", "english", "french", "german", "italian"] = Field(
        description="The language the text is written in"
    )


# Create the tagging prompt
tagging_prompt = ChatPromptTemplate.from_template(
    """Extract the desired information from the following passage.

Passage:
{input}
"""
)

# Create LLM with structured output
llm_with_structured_output = llm.with_structured_output(
    ClassificationSchema,
    method="function_calling",  # or try "json_schema"
)

# Test input (Spanish text)
test_input = "Estoy increiblemente contento de haberte conocido! Creo que seremos muy buenos amigos!"

# Structured-output methods exercised by the comparison run.
METHODS = ("function_calling", "json_schema")


def run_single_check() -> None:
    """Invoke the structured-output chain once and print the parsed fields.

    Any exception is caught and reported on stdout so the reproduction keeps
    going to the method comparison.
    """
    print("Testing Ollama structured output with Spanish text...")
    print(f"Input: {test_input}\n")

    try:
        # Format the prompt
        prompt = tagging_prompt.invoke({"input": test_input})

        # Get structured output
        result = llm_with_structured_output.invoke(prompt)

        print("✓ SUCCESS!")
        print(f"Result: {result}")
        print("\nParsed values:")
        print(f"  - Sentiment: {result.sentiment}")
        print(f"  - Aggressiveness: {result.aggressiveness}")
        print(f"  - Language: {result.language}")

    except Exception as e:  # top-level boundary of a repro script: report, don't crash
        print(f"✗ FAILED: {type(e).__name__}")
        print(f"Error: {e}")
        print("\nThis is the OutputParserException the fix addresses!")


def run_method_comparison() -> None:
    """Repeat the invocation once per entry in ``METHODS`` and print each outcome."""
    print("\n" + "=" * 60)
    print("Testing with different methods:")
    print("=" * 60)

    for method in METHODS:
        print(f"\nMethod: {method}")
        try:
            llm_test = llm.with_structured_output(
                ClassificationSchema,
                method=method,
            )
            result = llm_test.invoke(tagging_prompt.invoke({"input": test_input}))
            print(f"  ✓ Success: {result}")
        except Exception as e:  # report per-method failure and continue
            print(f"  ✗ Failed: {type(e).__name__}: {str(e)[:100]}")


def main() -> None:
    """Run the single check followed by the per-method comparison."""
    run_single_check()
    run_method_comparison()


if __name__ == "__main__":
    main()
+*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Poetry +poetry.lock diff --git a/libs/partners/sarvam/.mypy.ini b/libs/partners/sarvam/.mypy.ini new file mode 100644 index 0000000000000..23693a3d2e0e6 --- /dev/null +++ b/libs/partners/sarvam/.mypy.ini @@ -0,0 +1,9 @@ +[mypy] +disallow_untyped_defs = True +ignore_missing_imports = True +exclude = tests/ +explicit_package_bases = True +namespace_packages = True + +[mypy-tests.*] +disallow_untyped_defs = False diff --git a/libs/partners/sarvam/.ruff.toml b/libs/partners/sarvam/.ruff.toml new file mode 100644 index 0000000000000..8e46197545ef7 --- /dev/null +++ b/libs/partners/sarvam/.ruff.toml @@ -0,0 +1,19 @@ +[lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "T201", # print statements + "UP", # pyupgrade +] + +ignore = [ + "E501", # line too long (handled by formatter) +] + +[lint.per-file-ignores] +"tests/*" = ["T201"] # Allow print in tests +"scripts/*" = ["T201"] # Allow print in scripts + +[lint.isort] +known-first-party = ["langchain_sarvam"] diff --git a/libs/partners/sarvam/CONTRIBUTING.md b/libs/partners/sarvam/CONTRIBUTING.md new file mode 100644 index 0000000000000..c3e3ccd59cd34 --- /dev/null +++ b/libs/partners/sarvam/CONTRIBUTING.md @@ -0,0 +1,142 @@ +# 
Contributing to langchain-sarvam + +Thank you for your interest in contributing to langchain-sarvam! This document provides guidelines for contributing to this package. + +## Development Setup + +1. Clone the repository and navigate to the package directory: +```bash +cd libs/partners/sarvam +``` + +2. Install dependencies using Poetry: +```bash +poetry install --with test,lint,typing,dev +``` + +3. Set up your Sarvam API key for testing: +```bash +export SARVAM_API_KEY="your-api-key" +``` + +## Running Tests + +### Unit Tests +```bash +make test +# or +poetry run pytest tests/unit_tests +``` + +### Integration Tests +Integration tests require a valid Sarvam API key: +```bash +make integration_tests +# or +poetry run pytest tests/integration_tests +``` + +### All Tests +```bash +poetry run pytest tests/ +``` + +## Code Quality + +### Linting +```bash +make lint +``` + +This will run: +- `ruff` for code style checking +- `mypy` for type checking + +### Formatting +```bash +make format +``` + +This will automatically format your code using `ruff`. + +### Spell Checking +```bash +make spell_check +# To automatically fix spelling issues: +make spell_fix +``` + +## Before Submitting a PR + +1. **Run all tests**: Ensure all unit and integration tests pass +```bash +poetry run pytest tests/ +``` + +2. **Run linting**: Fix any linting errors +```bash +make lint +make format +``` + +3. **Check imports**: Verify imports are correct +```bash +make check_imports +``` + +4. 
**Update documentation**: If you've added new features, update: + - README.md + - Docstrings in the code + - Example notebooks if applicable + +## Code Style Guidelines + +- Follow PEP 8 style guidelines +- Use type hints for all function parameters and return values +- Write descriptive docstrings for all public methods and classes +- Keep functions focused and single-purpose +- Add comments for complex logic + +## Testing Guidelines + +- Write unit tests for all new functionality +- Ensure tests are isolated and don't depend on external services (use mocking) +- Integration tests should be marked with `@pytest.mark.scheduled` +- Test edge cases and error conditions + +## Pull Request Process + +1. Fork the repository and create a new branch for your feature +2. Make your changes following the guidelines above +3. Ensure all tests pass and code is properly formatted +4. Update documentation as needed +5. Submit a pull request with a clear description of changes + +## Common Issues + +### CI/CD Failures + +If you encounter CI failures: + +1. **Linting errors** (`lint` job failing): + - Run `make lint` locally + - Fix any `ruff` or `mypy` errors + - Run `make format` to auto-format code + +2. **Test failures**: + - Run tests locally: `poetry run pytest` + - Check if API key is properly set for integration tests + - Review test output for specific failures + +3. **Import errors**: + - Verify all imports are from allowed packages + - Run `make check_imports` + +## Questions or Problems? + +If you have questions or run into issues: +- Check existing GitHub issues +- Create a new issue with a clear description +- Join the LangChain Discord community + +Thank you for contributing! diff --git a/libs/partners/sarvam/LICENSE b/libs/partners/sarvam/LICENSE new file mode 100644 index 0000000000000..fc0602feecdd6 --- /dev/null +++ b/libs/partners/sarvam/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 LangChain, Inc. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/libs/partners/sarvam/Makefile b/libs/partners/sarvam/Makefile new file mode 100644 index 0000000000000..e4ae4392a3723 --- /dev/null +++ b/libs/partners/sarvam/Makefile @@ -0,0 +1,114 @@ +.PHONY: all format lint test tests integration_tests help extended_tests clean install + +# Default target +all: help + +###################### +# INSTALLATION +###################### + +install: + pip install -e ".[test]" + +###################### +# TESTING +###################### + +test tests: + pytest tests/unit_tests -v \ + --cov=langchain_sarvam \ + --cov-report=term-missing \ + --cov-report=xml \ + --junitxml=junit/test-results.xml + +integration_tests: + pytest tests/integration_tests -v \ + --cov=langchain_sarvam \ + --cov-report=term-missing \ + --cov-report=xml + +extended_tests: + $(MAKE) test + $(MAKE) integration_tests + +###################### +# LINTING AND FORMATTING +###################### + +# Define a variable for Python and notebook files. +PYTHON_FILES=. +MYPY_CACHE=.mypy_cache +lint format: PYTHON_FILES=. +lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/sarvam --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') +lint_package: PYTHON_FILES=langchain_sarvam +lint_tests: PYTHON_FILES=tests +lint_tests: MYPY_CACHE=.mypy_cache_test + +lint lint_diff lint_package lint_tests: + ruff check $(PYTHON_FILES) + ruff format --check $(PYTHON_FILES) + mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + +format format_diff: + ruff format $(PYTHON_FILES) + ruff check --select I --fix $(PYTHON_FILES) + +spell_check: + codespell --toml pyproject.toml + +spell_fix: + codespell --toml pyproject.toml -w + +check_imports: $(shell find langchain_sarvam -name '*.py') + python -m scripts.check_imports $^ + +###################### +# CLEANING +###################### + +clean: + find . -type d -name "__pycache__" -exec rm -rf {} + + find . -type d -name "*.egg-info" -exec rm -rf {} + + find . -type f -name "*.pyc" -delete + find . 
-type f -name "*.pyo" -delete + find . -type f -name ".coverage" -delete + find . -type d -name ".pytest_cache" -exec rm -rf {} + + find . -type d -name ".ruff_cache" -exec rm -rf {} + + find . -type d -name ".mypy_cache*" -exec rm -rf {} + + find . -type d -name "junit" -exec rm -rf {} + + rm -rf build/ + rm -rf dist/ + rm -rf .coverage* + rm -rf coverage.xml + rm -rf htmlcov/ + +###################### +# BUILD +###################### + +build: + python -m build + +publish: + python -m twine upload dist/* + +###################### +# HELP +###################### + +help: + @echo '----' + @echo 'install - install package in editable mode with test dependencies' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'test - run unit tests' + @echo 'tests - run unit tests (alias for test)' + @echo 'integration_tests - run integration tests' + @echo 'extended_tests - run all tests' + @echo 'clean - clean all build and test artifacts' + @echo 'build - build package' + @echo 'publish - publish package to PyPI' + @echo 'spell_check - run codespell on the project' + @echo 'spell_fix - run codespell on the project and fix issues' + @echo 'check_imports - check imports' + @echo 'help - print this help message' diff --git a/libs/partners/sarvam/README.md b/libs/partners/sarvam/README.md new file mode 100644 index 0000000000000..94ec2eb1033c8 --- /dev/null +++ b/libs/partners/sarvam/README.md @@ -0,0 +1,150 @@ +# langchain-sarvam + +This package contains the LangChain integration with [Sarvam AI](https://www.sarvam.ai/). + +Sarvam AI provides LLMs optimized for Indian languages understanding and efficiency, offering strong multilingual capabilities especially for Indic and low-resource languages. 
+ +## Installation + +```bash +pip install -U langchain-sarvam +``` + +## Setup + +To use the Sarvam AI models, you'll need to obtain an API key from Sarvam AI and set it as an environment variable: + +```bash +export SARVAM_API_KEY="your-api-key-here" +``` + +## Chat Models + +### Basic Usage + +```python +from langchain_sarvam import ChatSarvam +from langchain_core.messages import HumanMessage, SystemMessage + +# Initialize the model +chat = ChatSarvam( + model="sarvam-1", + temperature=0.7, + max_tokens=1024, +) + +# Create messages +messages = [ + SystemMessage(content="You are a helpful AI assistant."), + HumanMessage(content="What is the capital of India?"), +] + +# Get response +response = chat.invoke(messages) +print(response.content) +``` + +### Streaming + +```python +from langchain_sarvam import ChatSarvam + +chat = ChatSarvam(model="sarvam-1") + +for chunk in chat.stream("Tell me a short story about India"): + print(chunk.content, end="", flush=True) +``` + +### Async Usage + +```python +import asyncio +from langchain_sarvam import ChatSarvam +from langchain_core.messages import HumanMessage + +async def main(): + chat = ChatSarvam(model="sarvam-1") + response = await chat.ainvoke([HumanMessage(content="Hello!")]) + print(response.content) + +asyncio.run(main()) +``` + +### Configuration Options + +The `ChatSarvam` class supports various configuration parameters: + +- `model`: Model identifier (default: "sarvam-m") +- `temperature`: Controls randomness (0.0 to 2.0, default: 0.7) +- `max_tokens`: Maximum tokens to generate +- `top_p`: Nucleus sampling parameter +- `frequency_penalty`: Penalize token frequency (-2.0 to 2.0) +- `presence_penalty`: Penalize token presence (-2.0 to 2.0) +- `timeout`: Request timeout in seconds +- `max_retries`: Maximum retry attempts (default: 2) + +```python +chat = ChatSarvam( + model="sarvam-1", + temperature=0.9, + max_tokens=2048, + top_p=0.95, + timeout=30.0, +) +``` + +## Use Cases + +Sarvam AI models are particularly 
well-suited for: + +- **Multilingual Applications**: Strong support for Indian languages +- **RAG Pipelines**: Efficient retrieval-augmented generation +- **Conversational AI**: Building chatbots and assistants for Indian markets +- **Low-Resource Languages**: Working with languages that have limited training data + +## Integration with LangChain Components + +### Using with Chains + +```python +from langchain_sarvam import ChatSarvam +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser + +chat = ChatSarvam(model="sarvam-1") +prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful assistant."), + ("human", "{input}"), +]) + +chain = prompt | chat | StrOutputParser() +result = chain.invoke({"input": "What is AI?"}) +print(result) +``` + +### Using with Agents + +```python +from langchain_sarvam import ChatSarvam +from langchain.agents import AgentExecutor, create_react_agent +from langchain_core.tools import Tool + +chat = ChatSarvam(model="sarvam-1", temperature=0) + +# Define your tools and create an agent +# ... (agent setup code) +``` + +## API Reference + +For detailed API documentation, please refer to the [ChatSarvam API reference](https://python.langchain.com/api_reference/sarvam/chat_models.html). + +## Support + +For issues, questions, or contributions related to this integration: +- **LangChain Issues**: [GitHub Issues](https://github.com/langchain-ai/langchain/issues) +- **Sarvam AI Support**: Visit [Sarvam AI website](https://www.sarvam.ai/) + +## License + +This integration is released under the MIT License. 
diff --git a/libs/partners/sarvam/__init__.py b/libs/partners/sarvam/__init__.py new file mode 100644 index 0000000000000..46b6a7dad8ef5 --- /dev/null +++ b/libs/partners/sarvam/__init__.py @@ -0,0 +1,3 @@ +from langchain_sarvam.chat_models import ChatSarvam + +__all__ = ["ChatSarvam"] diff --git a/libs/partners/sarvam/junit/test-results.xml b/libs/partners/sarvam/junit/test-results.xml new file mode 100644 index 0000000000000..ab2df3e26d92d --- /dev/null +++ b/libs/partners/sarvam/junit/test-results.xml @@ -0,0 +1 @@ +D:\langchain\venv\Lib\site-packages\langchain_tests\unit_tests\chat_models.py:949: init_from_env_params not specified.D:\langchain\venv\Lib\site-packages\langchain_tests\unit_tests\chat_models.py:1109: Model is not serializable. \ No newline at end of file diff --git a/libs/partners/sarvam/langchain_sarvam/__init__.py b/libs/partners/sarvam/langchain_sarvam/__init__.py new file mode 100644 index 0000000000000..46b6a7dad8ef5 --- /dev/null +++ b/libs/partners/sarvam/langchain_sarvam/__init__.py @@ -0,0 +1,3 @@ +from langchain_sarvam.chat_models import ChatSarvam + +__all__ = ["ChatSarvam"] diff --git a/libs/partners/sarvam/langchain_sarvam/chat_models.py b/libs/partners/sarvam/langchain_sarvam/chat_models.py new file mode 100644 index 0000000000000..ce669390812be --- /dev/null +++ b/libs/partners/sarvam/langchain_sarvam/chat_models.py @@ -0,0 +1,355 @@ +"""Sarvam chat model integration.""" + +from __future__ import annotations + +import logging +from typing import Any, Iterator + +import requests +from langchain_core.callbacks import ( + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.utils import ( + convert_to_secret_str, + get_from_dict_or_env, +) 
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a Sarvam API message dict.

    Args:
        message: The LangChain message to serialize.

    Returns:
        A dict with ``role`` and ``content`` keys.

    Raises:
        ValueError: If ``message`` is none of the handled message types.
    """
    # ``ChatMessage`` is checked first because it carries its own role string.
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    elif isinstance(message, SystemMessage):
        role = "system"
    elif isinstance(message, ToolMessage):
        role = "tool"
    else:
        raise ValueError(f"Got unknown message type: {message}")

    return {"role": role, "content": message.content}


def _convert_dict_to_message(response: dict) -> BaseMessage:
    """Convert a Sarvam API response message dict to a LangChain message.

    Args:
        response: Message dict; ``role`` defaults to ``"assistant"`` and
            ``content`` to ``""`` when the key is missing.

    Returns:
        ``AIMessage``, ``HumanMessage`` or ``SystemMessage`` for the matching
        role, otherwise a ``ChatMessage`` carrying the role verbatim.
    """
    role = response.get("role", "assistant")
    content = response.get("content")
    if content is None:
        # ``dict.get(key, default)`` only covers a missing key; an explicit
        # JSON ``null`` would otherwise reach the message constructor as None.
        content = ""

    if role == "assistant":
        return AIMessage(content=content)
    if role == "user":
        return HumanMessage(content=content)
    if role == "system":
        return SystemMessage(content=content)
    return ChatMessage(content=content, role=role)


def _extract_error_detail(response: requests.Response) -> Any:
    """Return the decoded JSON error body, falling back to the raw text."""
    try:
        return response.json()
    except ValueError:  # body is not valid JSON
        return response.text


class ChatSarvam(BaseChatModel):
    """Sarvam AI chat model integration.

    Setup:
        Install ``langchain-sarvam`` and set environment variable ``SARVAM_API_KEY``.

        .. code-block:: bash

            pip install -U langchain-sarvam
            export SARVAM_API_KEY="your-api-key"

    Key init args — completion params:
        model: str
            Name of Sarvam model to use.
        temperature: float
            Sampling temperature.
        max_tokens: int | None
            Max number of tokens to generate.

    Key init args — client params:
        api_key: SecretStr | None
            Sarvam API key. If not provided, will read from SARVAM_API_KEY env var.
        base_url: str
            Base URL for Sarvam API.

    Instantiate:
        .. code-block:: python

            from langchain_sarvam import ChatSarvam

            llm = ChatSarvam(
                model="sarvam-m",
                temperature=0.7,
                max_tokens=1024,
                # api_key="your-api-key",
            )

    Invoke:
        .. code-block:: python

            messages = [
                ("system", "You are a helpful assistant."),
                ("human", "What is the capital of France?"),
            ]
            llm.invoke(messages)

        .. code-block:: python

            AIMessage(content='The capital of France is Paris.')

    Stream:
        .. code-block:: python

            for chunk in llm.stream(messages):
                print(chunk.content, end="", flush=True)

        .. code-block:: python

            The capital of France is Paris.

    Async:
        .. code-block:: python

            await llm.ainvoke(messages)

        .. code-block:: python

            AIMessage(content='The capital of France is Paris.')

    """

    model: str = Field(default="sarvam-m")
    """Model name to use."""

    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    """Sampling temperature."""

    max_tokens: int | None = Field(default=None)
    """Maximum number of tokens to generate."""

    top_p: float | None = Field(default=None, ge=0.0, le=1.0)
    """Nucleus sampling parameter."""

    frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
    """Penalize new tokens based on their frequency in the text so far."""

    presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
    """Penalize new tokens based on whether they appear in the text so far."""

    n: int = Field(default=1, ge=1)
    """Number of chat completions to generate for each prompt."""

    # NOTE(review): no method of this class reads ``streaming``; ``_generate``
    # always issues a non-streaming request. Confirm whether that is intended.
    streaming: bool = False
    """Whether to stream the results or not."""

    base_url: str = Field(default="https://api.sarvam.ai/v1")
    """Base URL for Sarvam API."""

    api_key: SecretStr | None = Field(default=None)
    """Sarvam API key."""

    timeout: float | None = Field(default=None)
    """Timeout for API requests in seconds."""

    # NOTE(review): no method of this class reads ``max_retries``; requests are
    # sent exactly once. Confirm whether retry handling is still to be added.
    max_retries: int = Field(default=2, ge=0)
    """Maximum number of retries for API requests."""

    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Additional model parameters."""

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict) -> dict:
        """Resolve the API key from the init values or ``SARVAM_API_KEY``.

        The resolved key is stored back into ``values`` as a secret string.
        """
        values["api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "api_key", "SARVAM_API_KEY")
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "sarvam-chat"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            **self.model_kwargs,
        }

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling Sarvam API.

        Optional sampling parameters are included only when they are set, and
        a fresh dict is built on every access, so callers may mutate it.
        """
        params: dict[str, Any] = {
            "model": self.model,
            "temperature": self.temperature,
            "n": self.n,
            **self.model_kwargs,
        }
        optional = {
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }
        params.update(
            {name: value for name, value in optional.items() if value is not None}
        )
        return params

    def _request_headers(self) -> dict[str, str]:
        """Build the HTTP headers sent with every chat-completions request."""
        return {
            "Authorization": f"Bearer {self.api_key.get_secret_value()}",  # type: ignore[union-attr]
            "Content-Type": "application/json",
        }

    def _build_payload(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None,
        extra: dict[str, Any],
        *,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Assemble the JSON body for a chat-completions request.

        ``extra`` (the per-call kwargs) overrides the defaults; ``stream`` and
        ``stop`` are applied afterwards so they take precedence over ``extra``.
        """
        params = self._default_params
        params.update(extra)
        if stream:
            params["stream"] = True
        if stop is not None:
            params["stop"] = stop
        return {
            "messages": [_convert_message_to_dict(m) for m in messages],
            **params,
        }

    def _create_chat_result(self, response: dict[str, Any]) -> ChatResult:
        """Create a ChatResult from a Sarvam API response."""
        generations = [
            ChatGeneration(
                message=_convert_dict_to_message(choice.get("message", {})),
                generation_info={
                    "finish_reason": choice.get("finish_reason"),
                    "index": choice.get("index"),
                },
            )
            for choice in response.get("choices", [])
        ]
        llm_output = {
            "token_usage": response.get("usage", {}),
            "model_name": response.get("model", self.model),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate chat response.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters overriding the defaults.

        Returns:
            The parsed ``ChatResult``.

        Raises:
            ValueError: If the API answers with an HTTP error status. The
                original ``HTTPError`` is attached as ``__cause__``.
        """
        response = requests.post(
            f"{self.base_url}/chat/completions",
            json=self._build_payload(messages, stop, kwargs),
            headers=self._request_headers(),
            timeout=self.timeout,
        )

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            error_detail = _extract_error_detail(response)
            logger.error("Sarvam API error: %s", error_detail)
            raise ValueError(
                f"Sarvam API request failed: {e}. Details: {error_detail}"
            ) from e

        return self._create_chat_result(response.json())

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat response.

        Reads ``data: ``-prefixed lines from the response, stops at ``[DONE]``
        and yields one chunk per non-empty content delta.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Notified of every new token when provided.
            **kwargs: Extra request parameters overriding the defaults.

        Yields:
            ``ChatGenerationChunk`` objects wrapping ``AIMessageChunk`` deltas.

        Raises:
            requests.exceptions.HTTPError: If the API answers with an HTTP
                error status.
        """
        # Kept function-scoped, as in the original, but out of the per-line loop.
        import json

        # The context manager closes the streamed connection even when the
        # consumer stops iterating early or an error is raised mid-stream.
        with requests.post(
            f"{self.base_url}/chat/completions",
            json=self._build_payload(messages, stop, kwargs, stream=True),
            headers=self._request_headers(),
            timeout=self.timeout,
            stream=True,
        ) as response:
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError:
                logger.error("Sarvam API error: %s", _extract_error_detail(response))
                raise

            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                if not line_str.startswith("data: "):
                    continue
                data_str = line_str[6:]
                if data_str.strip() == "[DONE]":
                    break

                # Only malformed payloads are skipped; errors raised by the
                # callback manager below are no longer swallowed.
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError as e:
                    logger.warning("Error parsing stream: %s", e)
                    continue
                if not isinstance(data, dict):
                    logger.warning("Error parsing stream: %s", "payload is not an object")
                    continue

                for choice in data.get("choices") or []:
                    delta = choice.get("delta") or {}
                    content = delta.get("content", "")
                    if not content:
                        continue
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=content)
                    )
                    if run_manager:
                        run_manager.on_llm_new_token(content, chunk=chunk)
                    yield chunk

    @property
    def _invocation_params(self) -> dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params
a/libs/partners/sarvam/py.typed b/libs/partners/sarvam/py.typed new file mode 100644 index 0000000000000..7632ecf77545c --- /dev/null +++ b/libs/partners/sarvam/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 diff --git a/libs/partners/sarvam/pyproject.toml b/libs/partners/sarvam/pyproject.toml new file mode 100644 index 0000000000000..56a7397fbfa9e --- /dev/null +++ b/libs/partners/sarvam/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "langchain-sarvam" +version = "1.0.0a1" +description = "An integration package connecting Sarvam AI and LangChain" +authors = ["LangChain"] +license = "MIT" +readme = "README.md" +repository = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/sarvam" +homepage = "https://docs.langchain.com/oss/python/integrations/providers/sarvam" + +[tool.poetry.dependencies] +python = ">=3.10,<4.0" +langchain-core = ">=1.0.0a7,<2.0.0" +requests = ">=2.31.0,<3.0.0" + +[tool.poetry.group.test.dependencies] +pytest = ">=7.3.0,<8.0.0" +freezegun = ">=1.2.2,<2.0.0" +pytest-mock = ">=3.10.0,<4.0.0" +syrupy = ">=4.0.2,<5.0.0" +pytest-watcher = ">=0.3.4,<1.0.0" +pytest-asyncio = ">=0.21.1,<1.0.0" +pytest-cov = ">=4.1.0,<5.0.0" +pytest-retry = ">=1.7.0,<1.8.0" +pytest-socket = ">=0.6.0,<1.0.0" +pytest-xdist = ">=3.6.1,<4.0.0" +langchain-core = ">=1.0.0a7,<2.0.0" +langchain-tests = { path = "../../standard-tests", develop = true } + +[tool.poetry.group.lint.dependencies] +ruff = ">=0.13.1,<0.14.0" + +[tool.poetry.group.dev.dependencies] +langchain-core = { path = "../../core", develop = true } + +[tool.poetry.group.test_integration.dependencies] +httpx = ">=0.27.0,<1.0.0" + +[tool.poetry.group.typing.dependencies] +mypy = ">=1.10.0,<2.0.0" +langchain-core = ">=1.0.0a7,<2.0.0" + +# Keep PEP 621 [project] for compatibility (LangChain uses dual metadata) +[project] +name = "langchain-sarvam" +version = "1.0.0a1" +description = "An 
def main() -> None:
    """Import the package modules and exit with a status reflecting the outcome.

    Exits with status 0 when every module imports cleanly. On the first
    ``ImportError`` the failure is written to stderr and the process exits
    with status 1.
    """
    exit_code = 0
    for module_name in ("langchain_sarvam", "langchain_sarvam.chat_models"):
        try:
            import_module(module_name)
        except ImportError as e:
            sys.stderr.write(f"✗ Import failed: {e}\n")
            exit_code = 1
            break  # stop at the first failure
    sys.exit(exit_code)
mode 100644 index 0000000000000..cb65eeb2fd1cc --- /dev/null +++ b/libs/partners/sarvam/tests/confest.py @@ -0,0 +1,21 @@ +"""Pytest configuration for langchain-sarvam tests.""" + +import os + +import pytest + + +def pytest_collection_modifyitems(config: pytest.Config, items: list) -> None: + """Mark tests that require API keys.""" + for item in items: + if "integration_tests" in str(item.fspath): + item.add_marker(pytest.mark.scheduled) + + +@pytest.fixture(scope="session") +def sarvam_api_key() -> str: + """Get Sarvam API key from environment.""" + key = os.environ.get("SARVAM_API_KEY", "") + if not key: + pytest.skip("SARVAM_API_KEY not set") + return key diff --git a/libs/partners/sarvam/tests/conftest.py b/libs/partners/sarvam/tests/conftest.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/partners/sarvam/tests/integration_tests/__init__.py b/libs/partners/sarvam/tests/integration_tests/__init__.py new file mode 100644 index 0000000000000..c8675d9575c1f --- /dev/null +++ b/libs/partners/sarvam/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""Integration tests for langchain-sarvam.""" diff --git a/libs/partners/sarvam/tests/integration_tests/test_chat_models.py b/libs/partners/sarvam/tests/integration_tests/test_chat_models.py new file mode 100644 index 0000000000000..56849a70518b2 --- /dev/null +++ b/libs/partners/sarvam/tests/integration_tests/test_chat_models.py @@ -0,0 +1,68 @@ +"""Integration tests for ChatSarvam. + +These tests require a valid SARVAM_API_KEY environment variable. 
+""" + +import pytest +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_tests.integration_tests import ChatModelIntegrationTests + +from langchain_sarvam import ChatSarvam + + +class TestChatSarvamIntegration(ChatModelIntegrationTests): + """Run the langchain_tests chat-model integration suite against ChatSarvam.""" + + @property + def chat_model_class(self) -> type[ChatSarvam]: + """Return the chat model class under test.""" + return ChatSarvam + + @property + def chat_model_params(self) -> dict: + """Return the constructor parameters for the model under test.""" + return {"model": "sarvam-m", "temperature": 0} + + +@pytest.mark.scheduled +def test_chat_sarvam_invoke() -> None: + """Invoke with a system and a human message; expect a non-empty string.""" + chat = ChatSarvam(model="sarvam-m") + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Say 'Hello World' and nothing else."), + ] + response = chat.invoke(messages) + assert isinstance(response.content, str) + assert len(response.content) > 0 + + +@pytest.mark.scheduled +def test_chat_sarvam_streaming() -> None: + """Stream a reply and check that the joined chunk contents are non-empty.""" + chat = ChatSarvam(model="sarvam-m", streaming=True) + messages = [HumanMessage(content="Count from 1 to 3")] + chunks = [] + for chunk in chat.stream(messages): + chunks.append(chunk.content) + + full_response = "".join(chunks) + assert len(full_response) > 0 + + +@pytest.mark.scheduled +def test_chat_sarvam_with_temperature() -> None: + """Invoke with temperature=0.9 and check that the reply content is a string.""" + chat = ChatSarvam(model="sarvam-m", temperature=0.9) + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response.content, str) + + +@pytest.mark.scheduled +def test_chat_sarvam_with_max_tokens() -> None: + """Invoke with max_tokens=50 and check that the reply content is a string.""" + chat = ChatSarvam(model="sarvam-m", max_tokens=50) + message = HumanMessage(content="Tell me a long story") + response = chat.invoke([message]) + assert isinstance(response.content, str) diff --git
a/libs/partners/sarvam/tests/unit_tests/__init__.py b/libs/partners/sarvam/tests/unit_tests/__init__.py new file mode 100644 index 0000000000000..3158b3f45c395 --- /dev/null +++ b/libs/partners/sarvam/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for langchain-sarvam.""" diff --git a/libs/partners/sarvam/tests/unit_tests/test_chat_models.py b/libs/partners/sarvam/tests/unit_tests/test_chat_models.py new file mode 100644 index 0000000000000..fbd745e6dddbd --- /dev/null +++ b/libs/partners/sarvam/tests/unit_tests/test_chat_models.py @@ -0,0 +1,49 @@ +"""Unit tests for ChatSarvam.""" + +from typing import Any + +from langchain_tests.unit_tests import ChatModelUnitTests + +from langchain_sarvam.chat_models import ChatSarvam + + +class TestChatSarvamUnit(ChatModelUnitTests): + """Run the langchain_tests chat-model unit suite against ChatSarvam.""" + + @property + def chat_model_class(self) -> type[ChatSarvam]: + """Return the chat model class under test.""" + return ChatSarvam + + @property + def chat_model_params(self) -> dict[str, Any]: + """Return the constructor parameters for the model under test.""" + return { + "model": "sarvam-1", + "api_key": "test-api-key", + "temperature": 0.7, + } + + +def test_initialization() -> None: + """Check the stored model name and that temperature is 0.7 when not passed.""" + llm = ChatSarvam(model="sarvam-1", api_key="test-key") + assert llm.model == "sarvam-1" + assert llm.temperature == 0.7 + + +def test_sarvam_model_param() -> None: + """Check that the model name passed to the constructor is stored.""" + llm = ChatSarvam(model="sarvam-2", api_key="test-key") + assert llm.model == "sarvam-2" + + +def test_sarvam_model_kwargs() -> None: + """Check that model_kwargs entries show up in _identifying_params.""" + llm = ChatSarvam( + model="sarvam-1", + api_key="test-key", + model_kwargs={"custom_param": "value"}, + ) + params = llm._identifying_params + assert params["custom_param"] == "value" diff --git a/libs/partners/sarvam/uv.lock b/libs/partners/sarvam/uv.lock new file mode 100644 index 0000000000000..d6dd03ba875f8 --- /dev/null +++ b/libs/partners/sarvam/uv.lock @@ -0,0 +1,56 @@ +schema-version = 2 + +[metadata] +lock-version =
"1.0" +python-versions = ["3.10", "3.11", "3.12", "3.13"] + +[package] +anyio = {version = "4.3.0", extras = ["trio"]} +certifi = {version = "2025.8.3"} +charset-normalizer = {version = "3.4.3"} +coverage = {version = "7.4.4"} +distlib = {version = "0.3.8"} +exceptiongroup = {version = "1.2.0"} +filelock = {version = "3.13.3"} +freezegun = {version = "1.2.2"} +httpx = {version = "0.27.0"} +idna = {version = "3.6"} +iniconfig = {version = "2.0.0"} +langchain-core = {path = "../../core", editable = true} +langchain-tests = {path = "../../standard-tests", editable = true} +mypy = {version = "1.9.0"} +mypy-extensions = {version = "1.0.0"} +packaging = {version = "25.0"} +pluggy = {version = "1.4.0"} +pytest = {version = "7.3.0"} +pytest-asyncio = {version = "0.21.1"} +pytest-cov = {version = "4.1.0"} +pytest-mock = {version = "3.10.0"} +pytest-retry = {version = "1.7.0"} +pytest-socket = {version = "0.6.0"} +pytest-xdist = {version = "3.6.1"} +python-dateutil = {version = "2.9.0.post0"} +requests = {version = "2.31.0"} +ruff = {version = "0.13.1"} +six = {version = "1.16.0"} +sniffio = {version = "1.3.1"} +syrupy = {version = "4.6.0"} +typing-extensions = {version = "4.10.0"} +urllib3 = {version = "2.2.1"} + +[package.dependencies] +anyio = [ + {version = "4.3.0", python = ">=3.8,<3.10"}, +] +exceptiongroup = [ + {version = "1.2.0", python = "<3.11"}, +] +tomli = [ + {version = "2.0.1", python = "<3.11"}, +] +typing-extensions = [ + {version = "4.10.0", python = "<3.13"}, +] + +[metadata.files] +# File hashes would normally be here; omitted for brevity