Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion providers/databricks/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ PIP package Version required
``apache-airflow-providers-common-sql`` ``>=1.27.0``
``requests`` ``>=2.32.0,<3``
``databricks-sql-connector`` ``>=4.0.0``
``databricks-sqlalchemy`` ``>=1.0.2``
``aiohttp`` ``>=3.9.2,<4``
``mergedeep`` ``>=1.3.4``
``pandas`` ``>=2.1.2; python_version < "3.13"``
Expand Down
5 changes: 4 additions & 1 deletion providers/databricks/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ dependencies = [
"apache-airflow-providers-common-sql>=1.27.0",
"requests>=2.32.0,<3",
"databricks-sql-connector>=4.0.0",
"databricks-sqlalchemy>=1.0.2",
"aiohttp>=3.9.2, <4",
"mergedeep>=1.3.4",
'pandas>=2.1.2; python_version <"3.13"',
Expand Down Expand Up @@ -91,6 +90,9 @@ dependencies = [
"openlineage" = [
"apache-airflow-providers-openlineage>=2.3.0"
]
"sqlalchemy" = [
"databricks-sqlalchemy>=1.0.2",
]

[dependency-groups]
dev = [
Expand All @@ -107,6 +109,7 @@ dev = [
"apache-airflow-providers-microsoft-azure",
"apache-airflow-providers-common-sql[pandas,polars]",
"apache-airflow-providers-fab",
"apache-airflow-providers-databricks[sqlalchemy]",
]

# To build docs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

from databricks import sql
from databricks.sql.types import Row
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
Expand All @@ -43,6 +43,7 @@

if TYPE_CHECKING:
from databricks.sql.client import Connection
from sqlalchemy.engine import URL

from airflow.models.connection import Connection as AirflowConnection
from airflow.providers.openlineage.extractors import OperatorLineage
Expand Down Expand Up @@ -179,6 +180,14 @@ def sqlalchemy_url(self) -> URL:

:return: the extracted sqlalchemy.engine.URL object.
"""
try:
from sqlalchemy.engine import URL
except ImportError:
raise AirflowOptionalProviderFeatureException(
"sqlalchemy is required to generate the connection URL. "
"Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
)

url_query = {
"http_path": self._http_path,
"catalog": self.catalog,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote

from sqlalchemy import select

from airflow.exceptions import TaskInstanceNotFound
from airflow.exceptions import AirflowOptionalProviderFeatureException, TaskInstanceNotFound
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey, clear_task_instances
from airflow.providers.common.compat.sdk import (
Expand Down Expand Up @@ -75,6 +73,10 @@ def get_databricks_task_ids(
from flask_appbuilder import BaseView
from flask_appbuilder.api import expose

try:
from sqlalchemy import select
except ImportError:
select = None # type: ignore[assignment,misc]
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www import auth

Expand Down Expand Up @@ -147,6 +149,11 @@ def _get_dagrun(dag, run_id: str, session: Session) -> DagRun:
:param session: The SQLAlchemy session to use for the query. If None, uses the default session.
:return: The DagRun object associated with the specified DAG and run_id.
"""
if select is None:
raise AirflowOptionalProviderFeatureException(
"sqlalchemy is required for workflow repair functionality. "
"Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
)
if not session:
raise AirflowException("Session not provided.")

Expand All @@ -166,6 +173,11 @@ def _clear_task_instances(

@provide_session
def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance:
if select is None:
raise AirflowOptionalProviderFeatureException(
"sqlalchemy is required to get task instance. "
"Install it with: pip install 'apache-airflow-providers-databricks[sqlalchemy]'"
)
dag_id = operator.dag.dag_id
if hasattr(DagRun, "execution_date"): # Airflow 2.x.
dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg]
Expand Down
4 changes: 4 additions & 0 deletions scripts/in_container/install_airflow_and_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,13 @@ def find_provider_distributions(extension: str, selected_providers: list[str]) -
for candidate in candidates:
# https://github.com/apache/airflow/pull/49339
path_str = candidate.as_posix()
# Add optional extras that we test providers with
if "apache_airflow_providers_common_sql" in path_str:
console.print(f"[bright_blue]Adding [polars] extra to common.sql provider: {path_str}")
path_str += "[polars]"
if "apache_airflow_providers_databricks" in path_str:
console.print(f"[bright_blue]Adding [sqlalchemy] extra to databricks provider: {path_str}")
path_str += "[sqlalchemy]"
result.append(path_str)
return result

Expand Down
Loading