Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@

# Names of the deployed inference pipelines in Databricks. These must match
# the job names configured in the Databricks workspace exactly, because jobs
# are resolved by name at run time (via w.jobs.list(name=...)).
PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"
CUSTOM_INFERENCE_JOB_NAME = "edvise_github_sourced_custom_inference_pipeline"


class DatabricksInferenceRunRequest(BaseModel):
"""Databricks parameters for an inference run."""
class DatabricksPDPInferenceRunRequest(BaseModel):
"""Databricks parameters for a PDP inference run."""

inst_name: str
# Note that the following should be the filepath.
Expand All @@ -50,6 +51,18 @@ class DatabricksInferenceRunRequest(BaseModel):
gcp_external_bucket_name: str


class DatabricksCustomInferenceRunRequest(BaseModel):
    """Databricks parameters for a custom schools inference run."""

    # Institution name; normalized via databricksify_inst_name() before
    # being passed to the Databricks job.
    inst_name: str
    # Name of the model to run inference with.
    model_name: str
    # File name of the pipeline config (forwarded as the
    # "config_file_name" job parameter).
    config_file_name: str
    # Name of the table holding the input features (forwarded as the
    # "features_table_name" job parameter).
    features_table_name: str
    # The email where notifications will get sent.
    email: str
    # External GCP bucket name (forwarded as the "gcp_bucket_name"
    # job parameter).
    gcp_external_bucket_name: str


class DatabricksInferenceRunResponse(BaseModel):
"""Databricks parameters for an inference run."""

Expand Down Expand Up @@ -186,7 +199,7 @@ def setup_new_inst(self, inst_name: str) -> None:
# E.g. there is one PDP inference pipeline, so one PDP inference function here.

def run_pdp_inference(
self, req: DatabricksInferenceRunRequest
self, req: DatabricksPDPInferenceRunRequest
) -> DatabricksInferenceRunResponse:
"""Triggers PDP inference Databricks run."""
LOGGER.info(f"Running PDP inference for institution: {req.inst_name}")
Expand Down Expand Up @@ -264,6 +277,73 @@ def run_pdp_inference(

return DatabricksInferenceRunResponse(job_run_id=run_id)

def run_custom_inference(
    self, req: DatabricksCustomInferenceRunRequest
) -> DatabricksInferenceRunResponse:
    """Triggers a custom schools inference Databricks run.

    Resolves the deployed custom inference job by name, then launches it
    with the institution- and model-specific parameters from ``req``.

    Args:
        req: Parameters identifying the institution, model, config file,
            features table, notification email, and external GCP bucket.

    Returns:
        DatabricksInferenceRunResponse carrying the triggered run's id.

    Raises:
        ValueError: If the workspace client cannot be created, the job
            cannot be found, the run cannot be started, or Databricks
            does not return a valid run_id.
    """
    LOGGER.info(f"Running custom inference for institution: {req.inst_name}")
    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        LOGGER.info("Successfully created Databricks WorkspaceClient.")
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            databricks_vars["DATABRICKS_HOST_URL"],
            gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        raise ValueError(
            f"run_custom_inference(): Workspace client initialization failed: {e}"
        ) from e

    db_inst_name = databricksify_inst_name(req.inst_name)
    pipeline_type = CUSTOM_INFERENCE_JOB_NAME

    try:
        # Job names are unique per workspace, so the first match (if any)
        # is the deployed custom inference pipeline.
        job = next(w.jobs.list(name=pipeline_type), None)
        if not job or job.job_id is None:
            raise ValueError(
                f"run_custom_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
            )
        job_id = job.job_id
        LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
    except Exception as e:
        LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.")
        raise ValueError(f"run_custom_inference(): Failed to find job: {e}") from e

    try:
        run_job: Any = w.jobs.run_now(
            job_id,
            job_parameters={
                "databricks_institution_name": db_inst_name,
                "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"],
                "model_name": req.model_name,
                "config_file_name": req.config_file_name,
                "features_table_name": req.features_table_name,
                "gcp_bucket_name": req.gcp_external_bucket_name,
                "datakind_notification_email": req.email,
                "DK_CC_EMAIL": req.email,
            },
        )
    except Exception as e:
        LOGGER.exception("Failed to run the custom inference job.")
        raise ValueError(f"run_custom_inference(): Job could not be run: {e}") from e

    # Validate the response outside the try block: dereferencing
    # run_job.response.run_id inside it would surface a missing response
    # as an AttributeError re-wrapped as "Job could not be run" instead of
    # the accurate "did not return a valid run_id" error. This also avoids
    # logging the success message twice.
    if not run_job.response or run_job.response.run_id is None:
        raise ValueError("run_custom_inference(): Job did not return a valid run_id.")

    run_id = run_job.response.run_id
    LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}")

    return DatabricksInferenceRunResponse(job_run_id=run_id)

def delete_inst(self, inst_name: str) -> None:
"""Cleanup tasks required on the Databricks side to delete an institution."""
db_inst_name = databricksify_inst_name(inst_name)
Expand Down
88 changes: 83 additions & 5 deletions src/webapp/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from sqlalchemy import and_, update, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
from ..databricks import (
DatabricksControl,
DatabricksPDPInferenceRunRequest,
DatabricksCustomInferenceRunRequest,
)
from ..utilities import (
has_access_to_inst_or_err,
has_full_data_access_or_err,
Expand Down Expand Up @@ -138,6 +142,9 @@ class InferenceRunRequest(BaseModel):
# Note: is_pdp is kept for backward compatibility but is ignored.
# PDP status is derived from the institution's pdp_id field.
is_pdp: bool = False
# Custom schools inference parameters (required for custom schools, ignored for PDP)
config_file_name: str | None = None
features_table_name: str | None = None


# Model related operations. Or model specific data.
Expand Down Expand Up @@ -524,11 +531,82 @@ def trigger_inference_run(
+ str(len(inst_result)),
)
inst = inst_result[0][0]
# Check PDP status from institution's pdp_id (ignore req.is_pdp for backward compat)
if not inst.pdp_id:
# Determine institution type: PDP, Edvise, or Legacy/Custom
# There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy/Custom (legacy_id or none)
# Follows the same pattern as validation_helper in data.py
pdp_id = getattr(inst, "pdp_id", None)
edvise_id = getattr(inst, "edvise_id", None)
legacy_id = getattr(inst, "legacy_id", None)
# Defensive check: ensure mutual exclusivity (should not happen if validation works correctly)
if sum(bool(x) for x in (pdp_id, edvise_id, legacy_id)) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set",
)
is_pdp = bool(pdp_id)
is_edvise = bool(edvise_id)
# Legacy and custom are the same thing - both use custom inference pipeline
is_legacy_or_custom = not is_pdp and not is_edvise

# Legacy/Custom schools inference
if is_legacy_or_custom:
if not req.config_file_name or not req.features_table_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Legacy/Custom schools inference requires config_file_name and features_table_name.",
)
# For legacy/custom schools, we don't need batch validation (config and features table are used instead)
db_req = DatabricksCustomInferenceRunRequest(
inst_name=inst_result[0][0].name,
model_name=model_name,
config_file_name=req.config_file_name,
features_table_name=req.features_table_name,
gcp_external_bucket_name=get_external_bucket_name(inst_id),
email=cast(str, current_user.email),
)
try:
res = databricks_control.run_custom_inference(db_req)
except Exception as e:
tb = traceback.format_exc()
logging.error(f"Databricks run failure:\n{tb}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Databricks run_custom_inference error. Error = {str(e)}",
) from e
triggered_timestamp = datetime.now()
latest_model_version = databricks_control.fetch_model_version(
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=inst_result[0][0].name,
model_name=model_name,
)
job = JobTable(
id=res.job_run_id,
triggered_at=triggered_timestamp,
created_by=str_to_uuid(current_user.user_id),
batch_name=f"{model_name}_{triggered_timestamp}", # Custom schools don't use batches
model_id=query_result[0][0].id,
output_valid=False,
model_version=latest_model_version.version,
model_run_id=latest_model_version.run_id,
)
local_session.get().add(job)
return {
"inst_id": inst_id,
"m_name": model_name,
"run_id": res.job_run_id,
"created_by": current_user.user_id,
"triggered_at": triggered_timestamp,
"batch_name": f"{model_name}_{triggered_timestamp}",
"output_valid": False,
"model_version": latest_model_version.version,
"model_run_id": latest_model_version.run_id,
}

# PDP inference (existing logic)
if not is_pdp:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Currently, only PDP inference is supported.",
detail="Currently, only PDP and Legacy/Custom schools inference are supported.",
)
query_result = (
local_session.get()
Expand Down Expand Up @@ -589,7 +667,7 @@ def trigger_inference_run(
detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}",
)
# Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines.
db_req = DatabricksInferenceRunRequest(
db_req = DatabricksPDPInferenceRunRequest(
inst_name=inst_result[0][0].name,
filepath_to_type=convert_files_to_dict(batch_result[0][0].files),
model_name=model_name,
Expand Down
Loading