Skip to content

Commit 816d1b2

Browse files
committed
Inference class scans for plugins
1 parent aa7a816 commit 816d1b2

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

pinecone/data/features/inference/inference.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1+
import logging
12
from typing import Optional, Dict, List, Union, Any
23

34
from pinecone.openapi_support import ApiClient
45
from pinecone.core.openapi.inference.apis import InferenceApi
56
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
67
from pinecone.core.openapi.inference import API_VERSION
7-
from pinecone.utils import setup_openapi_client
8+
from pinecone.utils import setup_openapi_client, build_plugin_setup_client
9+
10+
from pinecone_plugin_interface import load_and_install as install_plugins
811

912
from .inference_request_builder import (
1013
InferenceRequestBuilder,
1114
EmbedModel as EmbedModelEnum,
1215
RerankModel as RerankModelEnum,
1316
)
1417

18+
logger = logging.getLogger(__name__)
19+
1520

1621
class Inference:
1722
"""
@@ -27,6 +32,8 @@ class Inference:
2732

2833
def __init__(self, config, openapi_config, **kwargs):
2934
self.config = config
35+
self.openapi_config = openapi_config
36+
self.pool_threads = kwargs.get("pool_threads", 1)
3037

3138
self.__inference_api = setup_openapi_client(
3239
api_client_klass=ApiClient,
@@ -37,6 +44,23 @@ def __init__(self, config, openapi_config, **kwargs):
3744
api_version=API_VERSION,
3845
)
3946

47+
self.load_plugins()
48+
49+
def load_plugins(self):
50+
"""@private"""
51+
try:
52+
# I don't expect this to ever throw, but wrapping this in a
53+
# try block just in case to make sure a bad plugin doesn't
54+
# halt client initialization.
55+
openapi_client_builder = build_plugin_setup_client(
56+
config=self.config,
57+
openapi_config=self.openapi_config,
58+
pool_threads=self.pool_threads,
59+
)
60+
install_plugins(self, openapi_client_builder)
61+
except Exception as e:
62+
logger.error(f"Error loading plugins: {e}")
63+
4064
def embed(
4165
self,
4266
model: Union[EmbedModel, str],

0 commit comments

Comments
 (0)