Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,26 @@ class SaveDocumentRequest(BaseModel):
data: dict


def _get_features(
async def _get_features(
request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest],
store: "feast.FeatureStore",
):
if request.feature_service:
feature_service = store.get_feature_service(
request.feature_service, allow_cache=True
feature_service = await run_in_threadpool(
store.get_feature_service, request.feature_service, allow_cache=True
)
assert_permissions(
resource=feature_service, actions=[AuthzedAction.READ_ONLINE]
)
features = feature_service # type: ignore
else:
all_feature_views, all_on_demand_feature_views = (
utils._get_feature_views_to_use(
store.registry,
store.project,
request.features,
allow_cache=True,
hide_dummy_entity=False,
)
all_feature_views, all_on_demand_feature_views = await run_in_threadpool(
utils._get_feature_views_to_use,
store.registry,
store.project,
request.features,
allow_cache=True,
hide_dummy_entity=False,
)
for feature_view in all_feature_views:
assert_permissions(
Expand Down Expand Up @@ -230,7 +229,7 @@ async def lifespan(app: FastAPI):
)
async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, Any]:
# Initialize parameters for FeatureStore.get_online_features(...) call
features = await run_in_threadpool(_get_features, request, store)
features = await _get_features(request, store)

read_params = dict(
features=features,
Expand Down Expand Up @@ -265,7 +264,7 @@ async def retrieve_online_documents(
"This endpoint is in alpha and will be moved to /get-online-features when stable."
)
# Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call
features = await run_in_threadpool(_get_features, request, store)
features = await _get_features(request, store)

read_params = dict(features=features, query=request.query, top_k=request.top_k)
if request.api_version == 2 and request.query_string is not None:
Expand Down Expand Up @@ -342,26 +341,31 @@ async def push(request: PushFeaturesRequest) -> None:
else:
store.push(**push_params)

def _get_feast_object(
async def _get_feast_object(
feature_view_name: str, allow_registry_cache: bool
) -> FeastObject:
try:
return store.get_stream_feature_view( # type: ignore
feature_view_name, allow_registry_cache=allow_registry_cache
return await run_in_threadpool(
store.get_stream_feature_view,
feature_view_name,
allow_registry_cache=allow_registry_cache,
)
except FeatureViewNotFoundException:
return store.get_feature_view( # type: ignore
feature_view_name, allow_registry_cache=allow_registry_cache
return await run_in_threadpool(
store.get_feature_view,
feature_view_name,
allow_registry_cache=allow_registry_cache,
)

@app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)])
def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
async def write_to_online_store(request: WriteToFeatureStoreRequest) -> None:
df = pd.DataFrame(request.df)
feature_view_name = request.feature_view_name
allow_registry_cache = request.allow_registry_cache
resource = _get_feast_object(feature_view_name, allow_registry_cache)
resource = await _get_feast_object(feature_view_name, allow_registry_cache)
assert_permissions(resource=resource, actions=[AuthzedAction.WRITE_ONLINE])
store.write_to_online_store(
await run_in_threadpool(
store.write_to_online_store,
feature_view_name=feature_view_name,
df=df,
allow_registry_cache=allow_registry_cache,
Expand Down Expand Up @@ -428,10 +432,11 @@ async def chat_ui():
return Response(content=content, media_type="text/html")

@app.post("/materialize", dependencies=[Depends(inject_user_details)])
def materialize(request: MaterializeRequest) -> None:
async def materialize(request: MaterializeRequest) -> None:
for feature_view in request.feature_views or []:
resource = await _get_feast_object(feature_view, True)
assert_permissions(
resource=_get_feast_object(feature_view, True),
resource=resource,
actions=[AuthzedAction.WRITE_ONLINE],
)

Expand All @@ -450,22 +455,26 @@ def materialize(request: MaterializeRequest) -> None:
start_date = utils.make_tzaware(parser.parse(request.start_ts))
end_date = utils.make_tzaware(parser.parse(request.end_ts))

store.materialize(
await run_in_threadpool(
store.materialize,
start_date,
end_date,
request.feature_views,
disable_event_timestamp=request.disable_event_timestamp,
)

@app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)])
def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
async def materialize_incremental(request: MaterializeIncrementalRequest) -> None:
for feature_view in request.feature_views or []:
resource = await _get_feast_object(feature_view, True)
assert_permissions(
resource=_get_feast_object(feature_view, True),
resource=resource,
actions=[AuthzedAction.WRITE_ONLINE],
)
store.materialize_incremental(
utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views
await run_in_threadpool(
store.materialize_incremental,
utils.make_tzaware(parser.parse(request.end_ts)),
request.feature_views,
)

@app.exception_handler(Exception)
Expand Down
130 changes: 66 additions & 64 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,24 @@ class DynamoDBOnlineStore(OnlineStore):
Attributes:
_dynamodb_client: Boto3 DynamoDB client.
_dynamodb_resource: Boto3 DynamoDB resource.
_aioboto_session: Async boto session.
_aioboto_client: Async boto client.
_aioboto_context_stack: Async context stack.
"""

_dynamodb_client = None
_dynamodb_resource = None

def __init__(self):
super().__init__()
self._aioboto_session = None
self._aioboto_client = None
self._aioboto_context_stack = None

async def initialize(self, config: RepoConfig):
online_config = config.online_store

await _get_aiodynamodb_client(
await self._get_aiodynamodb_client(
online_config.region,
online_config.max_pool_connections,
online_config.keepalive_timeout,
Expand All @@ -127,7 +136,59 @@ async def initialize(self, config: RepoConfig):
)

async def close(self):
await _aiodynamodb_close()
await self._aiodynamodb_close()

def _get_aioboto_session(self):
if self._aioboto_session is None:
logger.debug("initializing the aiobotocore session")
self._aioboto_session = session.get_session()
return self._aioboto_session

async def _get_aiodynamodb_client(
self,
region: str,
max_pool_connections: int,
keepalive_timeout: float,
connect_timeout: Union[int, float],
read_timeout: Union[int, float],
total_max_retry_attempts: Union[int, None],
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None],
):
if self._aioboto_client is None:
logger.debug("initializing the aiobotocore dynamodb client")

retries: Dict[str, Any] = {}
if total_max_retry_attempts is not None:
retries["total_max_attempts"] = total_max_retry_attempts
if retry_mode is not None:
retries["mode"] = retry_mode

client_context = self._get_aioboto_session().create_client(
"dynamodb",
region_name=region,
config=AioConfig(
max_pool_connections=max_pool_connections,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
retries=retries if retries else None,
connector_args={"keepalive_timeout": keepalive_timeout},
),
)
self._aioboto_context_stack = contextlib.AsyncExitStack()
self._aioboto_client = (
await self._aioboto_context_stack.enter_async_context(client_context)
)
return self._aioboto_client

async def _aiodynamodb_close(self):
if self._aioboto_client:
await self._aioboto_client.close()
self._aioboto_client = None
if self._aioboto_context_stack:
await self._aioboto_context_stack.aclose()
self._aioboto_context_stack = None
if self._aioboto_session:
self._aioboto_session = None

@property
def async_supported(self) -> SupportedAsyncMethods:
Expand Down Expand Up @@ -362,7 +423,7 @@ async def online_write_batch_async(
_to_client_write_item(config, entity_key, features, timestamp)
for entity_key, features, timestamp, _ in _latest_data_to_write(data)
]
client = await _get_aiodynamodb_client(
client = await self._get_aiodynamodb_client(
online_config.region,
online_config.max_pool_connections,
online_config.keepalive_timeout,
Expand Down Expand Up @@ -473,7 +534,7 @@ def to_tbl_resp(raw_client_response):
batches.append(batch)
entity_id_batches.append(entity_id_batch)

client = await _get_aiodynamodb_client(
client = await self._get_aiodynamodb_client(
online_config.region,
online_config.max_pool_connections,
online_config.keepalive_timeout,
Expand Down Expand Up @@ -627,66 +688,7 @@ def _to_client_batch_get_payload(online_config, table_name, batch):
}


_aioboto_session = None
_aioboto_client = None
_aioboto_context_stack = None


def _get_aioboto_session():
global _aioboto_session
if _aioboto_session is None:
logger.debug("initializing the aiobotocore session")
_aioboto_session = session.get_session()
return _aioboto_session


async def _get_aiodynamodb_client(
region: str,
max_pool_connections: int,
keepalive_timeout: float,
connect_timeout: Union[int, float],
read_timeout: Union[int, float],
total_max_retry_attempts: Union[int, None],
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None],
):
global _aioboto_client, _aioboto_context_stack
if _aioboto_client is None:
logger.debug("initializing the aiobotocore dynamodb client")

retries: Dict[str, Any] = {}
if total_max_retry_attempts is not None:
retries["total_max_attempts"] = total_max_retry_attempts
if retry_mode is not None:
retries["mode"] = retry_mode

client_context = _get_aioboto_session().create_client(
"dynamodb",
region_name=region,
config=AioConfig(
max_pool_connections=max_pool_connections,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
retries=retries if retries else None,
connector_args={"keepalive_timeout": keepalive_timeout},
),
)
_aioboto_context_stack = contextlib.AsyncExitStack()
_aioboto_client = await _aioboto_context_stack.enter_async_context(
client_context
)
return _aioboto_client


async def _aiodynamodb_close():
global _aioboto_client, _aioboto_session, _aioboto_context_stack
if _aioboto_client:
await _aioboto_client.close()
_aioboto_client = None
if _aioboto_context_stack:
await _aioboto_context_stack.aclose()
_aioboto_context_stack = None
if _aioboto_session:
_aioboto_session = None
# Global async client functions removed - now using instance methods


def _initialize_dynamodb_client(
Expand Down
28 changes: 28 additions & 0 deletions sdk/python/tests/unit/test_feature_server_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from unittest.mock import AsyncMock, MagicMock

from fastapi.testclient import TestClient

from feast.feature_server import get_app
from feast.online_response import OnlineResponse
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse


def test_async_get_online_features():
"""Test that async get_online_features endpoint works correctly"""
fs = MagicMock()
fs._get_provider.return_value.async_supported.online.read = True
fs.get_online_features_async = AsyncMock(
return_value=OnlineResponse(GetOnlineFeaturesResponse())
)
fs.get_feature_service = MagicMock()
fs.initialize = AsyncMock()
fs.close = AsyncMock()

client = TestClient(get_app(fs))
response = client.post(
"/get-online-features",
json={"features": ["test:feature"], "entities": {"entity_id": [123]}},
)

assert response.status_code == 200
assert fs.get_online_features_async.await_count == 1
Loading