[Feat] asyncio for grpc data interactions #398

Draft · wants to merge 6 commits into base: main

9 changes: 9 additions & 0 deletions .github/actions/create-index/action.yml
@@ -30,6 +30,15 @@ outputs:
index_name:
description: 'The name of the index, including randomized suffix'
value: ${{ steps.create-index.outputs.index_name }}
index_host:
description: 'The host of the index'
value: ${{ steps.create-index.outputs.index_host }}
index_dimension:
description: 'The dimension of the index'
value: ${{ steps.create-index.outputs.index_dimension }}
index_metric:
description: 'The metric of the index'
value: ${{ steps.create-index.outputs.index_metric }}

runs:
using: 'composite'
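
The new index_host, index_dimension and index_metric outputs are consumed later by the asyncio test job. The script backing this action is not part of the diff, but a plausible sketch of how it could resolve and publish these values with the Python client (the INDEX_NAME variable and the output-file handling are assumptions):

import os
from pinecone import Pinecone

pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
description = pc.describe_index(os.environ["INDEX_NAME"])

# Append to $GITHUB_OUTPUT so downstream jobs can read the values as step outputs.
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
    f.write(f"index_host={description.host}\n")
    f.write(f"index_dimension={description.dimension}\n")
    f.write(f"index_metric={description.metric}\n")
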
55 changes: 55 additions & 0 deletions .github/actions/test-data-plane-asyncio/action.yaml
@@ -0,0 +1,55 @@
name: 'Test Data Plane Asyncio'
description: 'Runs asyncio tests on the Pinecone data plane'

inputs:
metric:
description: 'The metric of the index'
required: true
dimension:
description: 'The dimension of the index'
required: true
host:
description: 'The host of the index'
required: true
use_grpc:
description: 'Whether to use gRPC or REST'
required: true
freshness_timeout_seconds:
description: 'The number of seconds to wait for the index to become fresh'
required: false
default: '60'
PINECONE_API_KEY:
description: 'The Pinecone API key'
required: true
python_version:
description: 'The version of Python to use when running the tests'
required: true

outputs:
index_name:
description: 'The name of the index, including randomized suffix'
value: ${{ steps.create-index.outputs.index_name }}

runs:
using: 'composite'
steps:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python_version }}

- name: Setup Poetry
uses: ./.github/actions/setup-poetry
with:
include_grpc: ${{ inputs.use_grpc }}
include_dev: 'true'

- name: Run data plane tests
id: data-plane-tests
shell: bash
run: poetry run pytest tests/integration/data_asyncio
env:
PINECONE_API_KEY: ${{ inputs.PINECONE_API_KEY }}
USE_GRPC: ${{ inputs.use_grpc }}
METRIC: ${{ inputs.metric }}
INDEX_HOST: ${{ inputs.host }}
DIMENSION: ${{ inputs.dimension }}
SPEC: ${{ inputs.spec }}
FRESHNESS_TIMEOUT_SECONDS: ${{ inputs.freshness_timeout_seconds }}
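
The action hands its configuration to pytest through environment variables. A minimal sketch of how the asyncio suite under tests/integration/data_asyncio might read them in a conftest.py (the fixture names and defaults are illustrative assumptions, not taken from this diff):

import os
import pytest

@pytest.fixture(scope="session")
def index_host() -> str:
    # Populated by the test-data-plane-asyncio action from the create-index outputs.
    return os.environ["INDEX_HOST"]

@pytest.fixture(scope="session")
def dimension() -> int:
    return int(os.environ.get("DIMENSION", "100"))

@pytest.fixture(scope="session")
def use_grpc() -> bool:
    return os.environ.get("USE_GRPC", "false").lower() == "true"

@pytest.fixture(scope="session")
def freshness_timeout_seconds() -> int:
    return int(os.environ.get("FRESHNESS_TIMEOUT_SECONDS", "60"))
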
27 changes: 13 additions & 14 deletions .github/workflows/alpha-release.yaml
@@ -24,22 +24,22 @@ on:
default: 'rc1'

jobs:
unit-tests:
uses: './.github/workflows/testing-unit.yaml'
secrets: inherit
integration-tests:
uses: './.github/workflows/testing-integration.yaml'
secrets: inherit
dependency-tests:
uses: './.github/workflows/testing-dependency.yaml'
secrets: inherit
# unit-tests:
# uses: './.github/workflows/testing-unit.yaml'
# secrets: inherit
# integration-tests:
# uses: './.github/workflows/testing-integration.yaml'
# secrets: inherit
# dependency-tests:
# uses: './.github/workflows/testing-dependency.yaml'
# secrets: inherit

pypi:
uses: './.github/workflows/publish-to-pypi.yaml'
needs:
- unit-tests
- integration-tests
- dependency-tests
# needs:
# - unit-tests
# - integration-tests
# - dependency-tests
with:
isPrerelease: true
ref: ${{ inputs.ref }}
@@ -49,4 +49,3 @@ jobs:
secrets:
PYPI_USERNAME: __token__
PYPI_PASSWORD: ${{ secrets.PROD_PYPI_PUBLISH_TOKEN }}

58 changes: 58 additions & 0 deletions .github/workflows/testing-integration.yaml
@@ -32,6 +32,64 @@ jobs:
PINECONE_DEBUG_CURL: 'true'
PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}'

data-plane-setup:
name: Create index
runs-on: ubuntu-latest
outputs:
index_name: ${{ steps.setup-index.outputs.index_name }}
index_host: ${{ steps.setup-index.outputs.index_host }}
index_dimension: ${{ steps.setup-index.outputs.index_dimension }}
index_metric: ${{ steps.setup-index.outputs.index_metric }}
steps:
- uses: actions/checkout@v4
- name: Create index
id: setup-index
uses: ./.github/actions/create-index
timeout-minutes: 5
with:
dimension: 100
metric: 'cosine'
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}


test-data-plane-asyncio:
name: Data plane asyncio integration tests
runs-on: ubuntu-latest
needs:
- data-plane-setup
outputs:
index_name: ${{ needs.data-plane-setup.outputs.index_name }}
strategy:
fail-fast: false
matrix:
python_version: [3.8, 3.12]
use_grpc: [true]
spec:
- '{ "asyncio": { "environment": "us-east1-gcp" }}'
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/test-data-plane-asyncio
with:
python_version: '${{ matrix.python_version }}'
use_grpc: '${{ matrix.use_grpc }}'
metric: '${{ needs.data-plane-setup.outputs.index_metric }}'
dimension: '${{ needs.data-plane-setup.outputs.index_dimension }}'
host: '${{ needs.data-plane-setup.outputs.index_host }}'
PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}'
freshness_timeout_seconds: 600

data-plane-asyncio-cleanup:
name: Asyncio index cleanup
runs-on: ubuntu-latest
needs:
- test-data-plane-asyncio
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/delete-index
with:
index_name: '${{ needs.test-data-plane-asyncio.outputs.index_name }}'
PINECONE_API_KEY: '${{ secrets.PINECONE_API_KEY }}'

data-plane-serverless:
name: Data plane serverless integration tests
runs-on: ubuntu-latest
6 changes: 4 additions & 2 deletions .gitignore
@@ -137,7 +137,7 @@ venv.bak/
.ropeproject

# pdocs documentation
# We want to exclude any locally generated artifacts, but we rely on
# We want to exclude any locally generated artifacts, but we rely on
# keeping documentation assets in the docs/ folder.
docs/*
!docs/pinecone-python-client-fork.png
@@ -155,4 +155,6 @@ dmypy.json
*.hdf5
*~

tests/integration/proxy_config/logs
tests/integration/proxy_config/logs
*.parquet
app*.py
16 changes: 9 additions & 7 deletions pinecone/control/pinecone.py
@@ -1,6 +1,6 @@
import time
import logging
from typing import Optional, Dict, Any, Union, List, Tuple, Literal
from typing import Optional, Dict, Any, Union, Literal

from .index_host_store import IndexHostStore

@@ -10,7 +10,12 @@
from pinecone.core.openapi.shared.api_client import ApiClient


from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import (
normalize_host,
setup_openapi_client,
build_plugin_setup_client,
parse_non_empty_args,
)
from pinecone.core.openapi.control.models import (
CreateCollectionRequest,
CreateIndexRequest,
@@ -317,9 +322,6 @@ def create_index(

api_instance = self.index_api

def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
return {arg_name: val for arg_name, val in args if val is not None}

if deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
else:
@@ -329,7 +331,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
if "serverless" in spec:
index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
elif "pod" in spec:
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("environment", spec["pod"].get("environment")),
("metadata_config", spec["pod"].get("metadata_config")),
@@ -351,7 +353,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region)
)
elif isinstance(spec, PodSpec):
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("replicas", spec.replicas),
("shards", spec.shards),
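
The local _parse_non_empty_args helper removed above was a one-line dict comprehension, so the shared parse_non_empty_args now imported from pinecone.utils can reasonably be assumed to do the same thing. A sketch of that behavior, based on the deleted inline definition:

from typing import Any, Dict, List, Tuple

def parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
    # Drop any (name, value) pair whose value is None so that unset optional
    # fields are simply omitted from the request payload.
    return {arg_name: val for arg_name, val in args if val is not None}

# Only the fields that were actually provided survive:
assert parse_non_empty_args([("environment", "us-east1-gcp"), ("metadata_config", None)]) == {
    "environment": "us-east1-gcp"
}
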
3 changes: 3 additions & 0 deletions pinecone/exceptions/__init__.py
@@ -12,6 +12,8 @@
)
from .exceptions import PineconeConfigurationError, PineconeProtocolError, ListConversionException

PineconeNotFoundException = NotFoundException

__all__ = [
"PineconeConfigurationError",
"PineconeProtocolError",
@@ -22,6 +24,7 @@
"PineconeApiKeyError",
"PineconeApiException",
"NotFoundException",
"PineconeNotFoundException",
"UnauthorizedException",
"ForbiddenException",
"ServiceException",
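
PineconeNotFoundException is introduced as an alias of the existing NotFoundException, so either name catches the same error. A hypothetical usage sketch (the describe_index call is only for illustration):

from pinecone import Pinecone
from pinecone.exceptions import PineconeNotFoundException

pc = Pinecone(api_key="YOUR_API_KEY")

try:
    pc.describe_index("index-that-does-not-exist")
except PineconeNotFoundException:
    # Same class object as NotFoundException, so both names are interchangeable in except clauses.
    print("index not found")
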
1 change: 1 addition & 0 deletions pinecone/grpc/__init__.py
@@ -45,6 +45,7 @@
"""

from .index_grpc import GRPCIndex
from .index_grpc_asyncio import GRPCIndexAsyncio
from .pinecone import PineconeGRPC
from .config import GRPCClientConfig
from .future import PineconeGrpcFuture
17 changes: 11 additions & 6 deletions pinecone/grpc/base.py
@@ -10,6 +10,7 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from .utils import normalize_endpoint
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins
@@ -22,8 +23,6 @@ class GRPCIndexBase(ABC):
Base class for grpc-based interaction with Pinecone indexes
"""

_pool = None

def __init__(
self,
index_name: str,
@@ -32,6 +31,7 @@ def __init__(
grpc_config: Optional[GRPCClientConfig] = None,
pool_threads: Optional[int] = None,
_endpoint_override: Optional[str] = None,
use_asyncio: Optional[bool] = False,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
@@ -43,7 +43,7 @@ def __init__(
index_name=index_name, config=config, grpc_config=self.grpc_client_config
)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=use_asyncio
)
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)
@@ -74,9 +74,7 @@ def stub_class(self):
pass

def _endpoint(self):
grpc_host = self.config.host.replace("https://", "")
if ":" not in grpc_host:
grpc_host = f"{grpc_host}:443"
grpc_host = normalize_endpoint(self.config.host)
return self._endpoint_override if self._endpoint_override else grpc_host

def _gen_channel(self):
@@ -111,3 +109,10 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
self.close()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self.close()
return True
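
Two details worth noting in base.py: the host normalization previously inlined in _endpoint() now lives in a normalize_endpoint helper imported from .utils, and the class gains __aenter__/__aexit__ so asyncio callers can manage the channel with async with. Judging from the replaced inline code, normalize_endpoint likely behaves like this sketch:

def normalize_endpoint(host: str) -> str:
    # Strip the scheme and default to the standard TLS port, mirroring the
    # logic that used to be inlined in GRPCIndexBase._endpoint().
    grpc_host = host.replace("https://", "")
    if ":" not in grpc_host:
        grpc_host = f"{grpc_host}:443"
    return grpc_host

# e.g. "https://example-abc123.svc.us-east1-gcp.pinecone.io"
#   -> "example-abc123.svc.us-east1-gcp.pinecone.io:443"
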