1 change: 1 addition & 0 deletions pyproject.toml

@@ -10,6 +10,7 @@ dependencies = [
     "mlflow>=2.21.2",
     "nltk>=3.9.1",
     "numpy>=2.2.4",
+    "omegaconf>=2.3.0",
     "pandas>=2.2.3",
     "pendulum>=3.0.0",
     "pyarrow>=19.0.1",
4 changes: 2 additions & 2 deletions setup.sh

@@ -12,8 +12,8 @@ unset AWS_SESSION_TOKEN

 export MLFLOW_S3_ENDPOINT_URL="https://$AWS_S3_ENDPOINT"
 export MLFLOW_TRACKING_URI=https://projet-ape-mlflow.user.lab.sspcloud.fr
-export MLFLOW_MODEL_NAME=FastText-pytorch
-export MLFLOW_MODEL_VERSION="2"
+export MLFLOW_MODEL_NAME=test_wrapper_pytorch
+export MLFLOW_MODEL_VERSION="10"
 export API_USERNAME=username
 export API_PASSWORD=password
 export AUTH_API=False
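
Since both the model name and the pinned version change here, a quick sanity check that the registry entry exists can save a failed deployment. A minimal sketch, assuming the variables above are exported and the tracking server is reachable:

import os

from mlflow import MlflowClient

# Fails loudly here, rather than at API startup, if the pin is wrong.
client = MlflowClient()
mv = client.get_model_version(
    name=os.environ["MLFLOW_MODEL_NAME"],        # test_wrapper_pytorch
    version=os.environ["MLFLOW_MODEL_VERSION"],  # "10"
)
print(mv.status, mv.source)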
880 changes: 0 additions & 880 deletions src/api/data/libs.yaml

This file was deleted.

12 changes: 3 additions & 9 deletions src/api/main.py

@@ -5,16 +5,14 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-from pathlib import Path
 from typing import Annotated

 import mlflow
-import yaml
 from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import HTTPBasicCredentials

-from api.routes import predict_batch, predict_single
+from api.routes import predict
 from utils.logging import configure_logging
 from utils.security import get_credentials

@@ -26,10 +24,7 @@ async def lifespan(app: FastAPI):
     logger.info("🚀 Starting API lifespan")

     model_uri = f"models:/{os.environ['MLFLOW_MODEL_NAME']}/{os.environ['MLFLOW_MODEL_VERSION']}"
-    app.state.model = mlflow.pytorch.load_model(model_uri)
-
-    libs_path = Path("api/data/libs.yaml")
-    app.state.libs = yaml.safe_load(libs_path.read_text())
+    app.state.model = mlflow.pyfunc.load_model(model_uri)

     yield
     logger.info("🛑 Shutting down API lifespan")
@@ -42,8 +37,7 @@ async def lifespan(app: FastAPI):
     version="0.0.1",
 )

-app.include_router(predict_single.router)
-app.include_router(predict_batch.router)
+app.include_router(predict.router)

 app.add_middleware(
     CORSMiddleware,
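
Loading through mlflow.pyfunc (rather than mlflow.pytorch) is what lets the new route pass a params dict to predict(). The registered wrapper itself is not part of this diff; the sketch below only illustrates the PythonModel shape it implies, and the class name, artifact handling, and stub output are all assumptions:

from typing import Any, List, Optional

import mlflow.pyfunc


class TextClassifierWrapper(mlflow.pyfunc.PythonModel):  # hypothetical name
    def load_context(self, context):
        # The real wrapper would restore the torch model and its
        # preprocessing assets from context.artifacts; stubbed here.
        self.model = None

    def predict(self, context, model_input: List[Any], params: Optional[dict] = None):
        # `params` receives the dict built in the /predict route:
        # nb_echos_max, prob_min and dataloader_params.
        params = params or {}
        nb_echos_max = params.get("nb_echos_max", 5)
        prob_min = params.get("prob_min", 0.01)
        # The real model runs inference through a DataLoader configured
        # from params["dataloader_params"] and returns pydantic objects
        # (the route calls .model_dump() on each); plain dicts stand in here.
        return [{"nb_echos_max": nb_echos_max, "prob_min": prob_min} for _ in model_input]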
58 changes: 58 additions & 0 deletions src/api/routes/predict.py

@@ -0,0 +1,58 @@
+from typing import Annotated, List
+
+from fastapi import APIRouter, Depends, Request
+from fastapi.security import HTTPBasicCredentials
+
+from api.models.forms import BatchForms
+from api.models.responses import PredictionResponse
+from utils.preprocessing import mappings
+from utils.security import get_credentials
+
+APE_NIV5_MAPPING = mappings["nace2025"]
+INV_APE_NIV5_MAPPING = {v: k for k, v in APE_NIV5_MAPPING.items()}
+
+router = APIRouter(prefix="/predict", tags=["Predict NACE code for a list of activities"])
+
+
+@router.post("/predict", response_model=List[PredictionResponse])
+async def predict(
+    credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
+    request: Request,
+    forms: BatchForms,
+    nb_echos_max: int = 5,
+    prob_min: float = 0.01,
+    num_workers: int = 1,
+    batch_size: int = 1,
+):
+    """
+    Endpoint for predicting batches of data.
+
+    Args:
+        credentials (HTTPBasicCredentials): The credentials for authentication.
+        forms (BatchForms): The input data as a BatchForms object.
+        nb_echos_max (int, optional): Maximum number of predictions to return. Defaults to 5.
+        prob_min (float, optional): Minimum probability threshold for predictions. Defaults to 0.01.
+        num_workers (int, optional): Number of CPU workers for multiprocessing in the DataLoader. Defaults to 1.
+        batch_size (int, optional): Batch size for batch prediction. Defaults to 1.
+
+    For single predictions, keep num_workers and batch_size at 1 for best performance.
+    For batched predictions, consider increasing both (num_workers from 4 to 12, batch_size up to 256) to optimize throughput.
+
+    Returns:
+        list: The list of prediction responses.
+    """
+    input_data = forms.forms
+
+    params_dict = {
+        "nb_echos_max": nb_echos_max,
+        "prob_min": prob_min,
+        "dataloader_params": {
+            "pin_memory": False,
+            "persistent_workers": False,
+            "num_workers": num_workers,
+            "batch_size": batch_size,
+        },
+    }
+
+    output = request.app.state.model.predict(input_data, params=params_dict)
+    return [out.model_dump() for out in output]
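
A hypothetical client call against the new endpoint (note the effective path is /predict/predict, since the router prefix and the route path are both "/predict"). The host, the credentials, and the shape of the form payload are assumptions; this diff only shows that BatchForms exposes a .forms list:

import requests

resp = requests.post(
    "http://localhost:8000/predict/predict",
    json={"forms": [{"description_activity": "vente de pain et viennoiseries"}]},  # field name is hypothetical
    params={"nb_echos_max": 5, "prob_min": 0.01, "num_workers": 1, "batch_size": 1},
    auth=("username", "password"),  # placeholders from setup.sh
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # list of PredictionResponse payloads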
82 changes: 0 additions & 82 deletions src/api/routes/predict_batch.py

This file was deleted.

77 changes: 0 additions & 77 deletions src/api/routes/predict_single.py

This file was deleted.

1 change: 0 additions & 1 deletion src/utils/logging.py

@@ -8,7 +8,6 @@ def configure_logging():
         level=logging.INFO,
         format="%(asctime)s - %(levelname)s - %(message)s",
         handlers=[
-            logging.FileHandler("codification_ape_log_file.log"),
             logging.StreamHandler(),
         ],
     )
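
For reference, the resulting function after this change, reconstructed from the hunk above. Console-only logging is a reasonable fit for container deployments, where a log file written inside the pod is lost on restart:

import logging


def configure_logging():
    # Stream handler only: logs go to stderr and can be collected by the
    # container runtime instead of accumulating in a file inside the pod.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(),
        ],
    )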
21 changes: 21 additions & 0 deletions uv.lock

Some generated files are not rendered by default.