diff --git a/src/webapp/databricks.py b/src/webapp/databricks.py
index 5820415e..4ba1fe16 100644
--- a/src/webapp/databricks.py
+++ b/src/webapp/databricks.py
@@ -36,10 +36,11 @@
 
 # The name of the deployed pipeline in Databricks. Must match directly.
 PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"
+CUSTOM_INFERENCE_JOB_NAME = "edvise_github_sourced_custom_inference_pipeline"
 
 
-class DatabricksInferenceRunRequest(BaseModel):
-    """Databricks parameters for an inference run."""
+class DatabricksPDPInferenceRunRequest(BaseModel):
+    """Databricks parameters for a PDP inference run."""
 
     inst_name: str
     # Note that the following should be the filepath.
@@ -50,6 +51,18 @@ class DatabricksInferenceRunRequest(BaseModel):
     gcp_external_bucket_name: str
 
 
+class DatabricksCustomInferenceRunRequest(BaseModel):
+    """Databricks parameters for a custom schools inference run."""
+
+    inst_name: str
+    model_name: str
+    config_file_name: str
+    features_table_name: str
+    # The email where notifications will get sent.
+    email: str
+    gcp_external_bucket_name: str
+
+
 class DatabricksInferenceRunResponse(BaseModel):
     """Databricks parameters for an inference run."""
 
@@ -186,7 +199,7 @@ def setup_new_inst(self, inst_name: str) -> None:
 
     # E.g. there is one PDP inference pipeline, so one PDP inference function here.
     def run_pdp_inference(
-        self, req: DatabricksInferenceRunRequest
+        self, req: DatabricksPDPInferenceRunRequest
     ) -> DatabricksInferenceRunResponse:
         """Triggers PDP inference Databricks run."""
         LOGGER.info(f"Running PDP inference for institution: {req.inst_name}")
@@ -264,6 +277,73 @@ def run_pdp_inference(
 
         return DatabricksInferenceRunResponse(job_run_id=run_id)
 
+    def run_custom_inference(
+        self, req: DatabricksCustomInferenceRunRequest
+    ) -> DatabricksInferenceRunResponse:
+        """Triggers custom schools inference Databricks run."""
+        LOGGER.info(f"Running custom inference for institution: {req.inst_name}")
+        try:
+            w = WorkspaceClient(
+                host=databricks_vars["DATABRICKS_HOST_URL"],
+                google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+            )
+            LOGGER.info("Successfully created Databricks WorkspaceClient.")
+        except Exception as e:
+            LOGGER.exception(
+                "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
+                databricks_vars["DATABRICKS_HOST_URL"],
+                gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
+            )
+            raise ValueError(
+                f"run_custom_inference(): Workspace client initialization failed: {e}"
+            )
+
+        db_inst_name = databricksify_inst_name(req.inst_name)
+        pipeline_type = CUSTOM_INFERENCE_JOB_NAME
+
+        try:
+            job = next(w.jobs.list(name=pipeline_type), None)
+            if not job or job.job_id is None:
+                raise ValueError(
+                    f"run_custom_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
+                )
+            job_id = job.job_id
+            LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
+        except Exception as e:
+            LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.")
+            raise ValueError(f"run_custom_inference(): Failed to find job: {e}")
+
+        try:
+            run_job: Any = w.jobs.run_now(
+                job_id,
+                job_parameters={
+                    "databricks_institution_name": db_inst_name,
+                    "DB_workspace": databricks_vars[
+                        "DATABRICKS_WORKSPACE"
+                    ],
+                    "model_name": req.model_name,
+                    "config_file_name": req.config_file_name,
+                    "features_table_name": req.features_table_name,
+                    "gcp_bucket_name": req.gcp_external_bucket_name,
+                    "datakind_notification_email": req.email,
+                    "DK_CC_EMAIL": req.email,
+                },
+            )
+            LOGGER.info(
+                f"Successfully triggered job run. Run ID: {run_job.response.run_id}"
+            )
+        except Exception as e:
+            LOGGER.exception("Failed to run the custom inference job.")
+            raise ValueError(f"run_custom_inference(): Job could not be run: {e}")
+
+        if not run_job.response or run_job.response.run_id is None:
+            raise ValueError("run_custom_inference(): Job did not return a valid run_id.")
+
+        run_id = run_job.response.run_id
+        LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}")
+
+        return DatabricksInferenceRunResponse(job_run_id=run_id)
+
     def delete_inst(self, inst_name: str) -> None:
         """Cleanup tasks required on the Databricks side to delete an institution."""
         db_inst_name = databricksify_inst_name(inst_name)
diff --git a/src/webapp/routers/models.py b/src/webapp/routers/models.py
index 02be74ae..9aa8606f 100644
--- a/src/webapp/routers/models.py
+++ b/src/webapp/routers/models.py
@@ -8,7 +8,11 @@
 from sqlalchemy import and_, update, or_
 from sqlalchemy.orm import Session
 from sqlalchemy.future import select
-from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
+from ..databricks import (
+    DatabricksControl,
+    DatabricksPDPInferenceRunRequest,
+    DatabricksCustomInferenceRunRequest,
+)
 from ..utilities import (
     has_access_to_inst_or_err,
     has_full_data_access_or_err,
@@ -138,6 +142,9 @@ class InferenceRunRequest(BaseModel):
     # Note: is_pdp is kept for backward compatibility but is ignored.
     # PDP status is derived from the institution's pdp_id field.
     is_pdp: bool = False
+    # Custom schools inference parameters (required for custom schools, ignored for PDP)
+    config_file_name: str | None = None
+    features_table_name: str | None = None
 
 
 # Model related operations. Or model specific data.
@@ -524,11 +531,82 @@ def trigger_inference_run(
             + str(len(inst_result)),
         )
     inst = inst_result[0][0]
-    # Check PDP status from institution's pdp_id (ignore req.is_pdp for backward compat)
-    if not inst.pdp_id:
+    # Determine institution type: PDP, Edvise, or Legacy/Custom
+    # There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy/Custom (legacy_id or none)
+    # Follows the same pattern as validation_helper in data.py
+    pdp_id = getattr(inst, "pdp_id", None)
+    edvise_id = getattr(inst, "edvise_id", None)
+    legacy_id = getattr(inst, "legacy_id", None)
+    # Defensive check: ensure mutual exclusivity (should not happen if validation works correctly)
+    if sum(bool(x) for x in (pdp_id, edvise_id, legacy_id)) > 1:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set",
+        )
+    is_pdp = bool(pdp_id)
+    is_edvise = bool(edvise_id)
+    # Legacy and custom are the same thing - both use custom inference pipeline
+    is_legacy_or_custom = not is_pdp and not is_edvise
+
+    # Legacy/Custom schools inference
+    if is_legacy_or_custom:
+        if not req.config_file_name or not req.features_table_name:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Legacy/Custom schools inference requires config_file_name and features_table_name.",
+            )
+        # For legacy/custom schools, we don't need batch validation (config and features table are used instead)
+        db_req = DatabricksCustomInferenceRunRequest(
+            inst_name=inst_result[0][0].name,
+            model_name=model_name,
+            config_file_name=req.config_file_name,
+            features_table_name=req.features_table_name,
+            gcp_external_bucket_name=get_external_bucket_name(inst_id),
+            email=cast(str, current_user.email),
+        )
+        try:
+            res = databricks_control.run_custom_inference(db_req)
+        except Exception as e:
+            tb = traceback.format_exc()
+            logging.error(f"Databricks run failure:\n{tb}")
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Databricks run_custom_inference error. Error = {str(e)}",
+            ) from e
+        triggered_timestamp = datetime.now()
+        latest_model_version = databricks_control.fetch_model_version(
+            catalog_name=str(env_vars["CATALOG_NAME"]),
+            inst_name=inst_result[0][0].name,
+            model_name=model_name,
+        )
+        job = JobTable(
+            id=res.job_run_id,
+            triggered_at=triggered_timestamp,
+            created_by=str_to_uuid(current_user.user_id),
+            batch_name=f"{model_name}_{triggered_timestamp}",  # Custom schools don't use batches
+            model_id=query_result[0][0].id,
+            output_valid=False,
+            model_version=latest_model_version.version,
+            model_run_id=latest_model_version.run_id,
+        )
+        local_session.get().add(job)
+        return {
+            "inst_id": inst_id,
+            "m_name": model_name,
+            "run_id": res.job_run_id,
+            "created_by": current_user.user_id,
+            "triggered_at": triggered_timestamp,
+            "batch_name": f"{model_name}_{triggered_timestamp}",
+            "output_valid": False,
+            "model_version": latest_model_version.version,
+            "model_run_id": latest_model_version.run_id,
+        }
+
+    # PDP inference (existing logic)
+    if not is_pdp:
         raise HTTPException(
             status_code=status.HTTP_501_NOT_IMPLEMENTED,
-            detail="Currently, only PDP inference is supported.",
+            detail="Currently, only PDP and Legacy/Custom schools inference are supported.",
         )
     query_result = (
         local_session.get()
@@ -589,7 +667,7 @@ def trigger_inference_run(
             detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}",
         )
     # Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines.
-    db_req = DatabricksInferenceRunRequest(
+    db_req = DatabricksPDPInferenceRunRequest(
         inst_name=inst_result[0][0].name,
         filepath_to_type=convert_files_to_dict(batch_result[0][0].files),
         model_name=model_name,