diff --git a/providers/databricks/docs/index.rst b/providers/databricks/docs/index.rst
index faea8d2b05d02..38708434bdb65 100644
--- a/providers/databricks/docs/index.rst
+++ b/providers/databricks/docs/index.rst
@@ -106,7 +106,6 @@ PIP package                               Version required
 ``apache-airflow-providers-common-sql``   ``>=1.27.0``
 ``requests``                              ``>=2.32.0,<3``
 ``databricks-sql-connector``              ``>=4.0.0``
-``databricks-sqlalchemy``                 ``>=1.0.2``
 ``aiohttp``                               ``>=3.9.2,<4``
 ``mergedeep``                             ``>=1.3.4``
 ``pandas``                                ``>=2.1.2; python_version < "3.13"``
diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml
index 2f76d840ce2b6..f800eff99793a 100644
--- a/providers/databricks/pyproject.toml
+++ b/providers/databricks/pyproject.toml
@@ -63,7 +63,6 @@ dependencies = [
     "apache-airflow-providers-common-sql>=1.27.0",
     "requests>=2.32.0,<3",
     "databricks-sql-connector>=4.0.0",
-    "databricks-sqlalchemy>=1.0.2",
     "aiohttp>=3.9.2, <4",
     "mergedeep>=1.3.4",
     'pandas>=2.1.2; python_version <"3.13"',
@@ -91,6 +90,9 @@ dependencies = [
 "openlineage" = [
     "apache-airflow-providers-openlineage>=2.3.0"
 ]
+"sqlalchemy" = [
+    "databricks-sqlalchemy>=1.0.2",
+]
 
 [dependency-groups]
 dev = [
@@ -107,6 +109,7 @@ dev = [
     "apache-airflow-providers-microsoft-azure",
     "apache-airflow-providers-common-sql[pandas,polars]",
     "apache-airflow-providers-fab",
+    "apache-airflow-providers-databricks[sqlalchemy]",
 ]
 
 # To build docs:
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 2d05946ec0891..2684b00ac2841 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -32,8 +32,8 @@
 from databricks import sql
 from databricks.sql.types import Row
-from sqlalchemy.engine import URL
 
+from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.common.sql.hooks.handlers import return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -43,6 +43,7 @@
 
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
+    from sqlalchemy.engine import URL
 
     from airflow.models.connection import Connection as AirflowConnection
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -179,6 +180,14 @@ def sqlalchemy_url(self) -> URL:
 
         :return: the extracted sqlalchemy.engine.URL object.
         """
+        try:
+            from sqlalchemy.engine import URL
+        except ImportError:
+            raise AirflowOptionalProviderFeatureException(
+                "sqlalchemy is required to generate the connection URL. "
+                "Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
+            )
+
         url_query = {
             "http_path": self._http_path,
             "catalog": self.catalog,
diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
index dd595c49b7f7d..94fa56a863e9a 100644
--- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -20,9 +20,7 @@
 from typing import TYPE_CHECKING, Any
 from urllib.parse import unquote
 
-from sqlalchemy import select
-
-from airflow.exceptions import TaskInstanceNotFound
+from airflow.exceptions import AirflowOptionalProviderFeatureException, TaskInstanceNotFound
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance, TaskInstanceKey, clear_task_instances
 from airflow.providers.common.compat.sdk import (
@@ -75,6 +73,10 @@ def get_databricks_task_ids(
     from flask_appbuilder import BaseView
     from flask_appbuilder.api import expose
 
+    try:
+        from sqlalchemy import select
+    except ImportError:
+        select = None  # type: ignore[assignment,misc]
     from airflow.utils.session import NEW_SESSION, provide_session
     from airflow.www import auth
@@ -147,6 +149,11 @@ def _get_dagrun(dag, run_id: str, session: Session) -> DagRun:
     :param session: The SQLAlchemy session to use for the query. If None, uses the default session.
     :return: The DagRun object associated with the specified DAG and run_id.
     """
+    if select is None:
+        raise AirflowOptionalProviderFeatureException(
+            "sqlalchemy is required for workflow repair functionality. "
+            "Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
+        )
     if not session:
         raise AirflowException("Session not provided.")
@@ -166,6 +173,11 @@ def _clear_task_instances(
 
 @provide_session
 def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
+    if select is None:
+        raise AirflowOptionalProviderFeatureException(
+            "sqlalchemy is required to get task instance. "
+            "Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
+        )
     dag_id = operator.dag.dag_id
     if hasattr(DagRun, "execution_date"):  # Airflow 2.x.
         dag_run = DagRun.find(dag_id, execution_date=dttm)[0]  # type: ignore[call-arg]
diff --git a/scripts/in_container/install_airflow_and_providers.py b/scripts/in_container/install_airflow_and_providers.py
index 09b349d0110c5..d93bd9d925ac6 100755
--- a/scripts/in_container/install_airflow_and_providers.py
+++ b/scripts/in_container/install_airflow_and_providers.py
@@ -128,9 +128,13 @@ def find_provider_distributions(extension: str, selected_providers: list[str]) -> list[str]:
     for candidate in candidates:
         # https://github.com/apache/airflow/pull/49339
         path_str = candidate.as_posix()
+        # Add optional extras that we test providers with
         if "apache_airflow_providers_common_sql" in path_str:
             console.print(f"[bright_blue]Adding [polars] extra to common.sql provider: {path_str}")
             path_str += "[polars]"
+        if "apache_airflow_providers_databricks" in path_str:
+            console.print(f"[bright_blue]Adding [sqlalchemy] extra to databricks provider: {path_str}")
+            path_str += "[sqlalchemy]"
         result.append(path_str)
     return result