Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20250818-140222.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Add ability to add extra headers during a request
time: 2025-08-18T14:02:22.72837-04:00
6 changes: 4 additions & 2 deletions dbtsl/api/adbc/client/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from typing import AsyncIterator, Dict, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
Expand All @@ -18,6 +18,7 @@ def __init__(
environment_id: int,
auth_token: str,
url_format: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize the ADBC client.

Expand All @@ -29,8 +30,9 @@ def __init__(
into a full URL. If `None`, the default
`grpc+tls://{server_host}:443`
will be assumed.
extra_headers: extra headers to be sent with the request.
"""
super().__init__(server_host, environment_id, auth_token, url_format)
super().__init__(server_host, environment_id, auth_token, url_format, extra_headers=extra_headers)
self._loop = asyncio.get_running_loop()

@asynccontextmanager
Expand Down
15 changes: 13 additions & 2 deletions dbtsl/api/adbc/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextlib import AbstractContextManager
from typing import Dict, Generic, Optional, Protocol, TypeVar, Union

from adbc_driver_flightsql import DatabaseOptions
from adbc_driver_flightsql import DatabaseOptions, StatementOptions
from adbc_driver_flightsql.dbapi import Connection
from adbc_driver_flightsql.dbapi import connect as adbc_connect # pyright: ignore[reportUnknownVariableType]
from adbc_driver_manager import AdbcStatusCode, ProgrammingError
Expand Down Expand Up @@ -33,11 +33,13 @@ def __init__( # noqa: D107
environment_id: int,
auth_token: str,
url_format: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
url_format = url_format or self.DEFAULT_URL_FORMAT
self._conn_str = url_format.format(server_host=server_host)
self._environment_id = environment_id
self._auth_token = auth_token
self._extra_headers = extra_headers or {}

self._conn_unsafe: Union[Connection, None] = None

Expand All @@ -48,6 +50,7 @@ def _get_connection_context_manager(self) -> AbstractContextManager[Connection]:
DatabaseOptions.AUTHORIZATION_HEADER.value: f"Bearer {self._auth_token}",
f"{DatabaseOptions.RPC_CALL_HEADER_PREFIX.value}environmentid": str(self._environment_id),
**self._extra_db_kwargs(),
**{f"{StatementOptions.RPC_CALL_HEADER_PREFIX.value}{k}": v for k, v in self._extra_headers.items()},
},
)

Expand Down Expand Up @@ -88,13 +91,21 @@ def has_session(self) -> bool:

class ADBCClientFactory(Protocol, Generic[TClient]): # noqa: D101
@abstractmethod
def __call__(self, server_host: str, environment_id: int, auth_token: str, url_format: str) -> TClient:
def __call__(
self,
server_host: str,
environment_id: int,
auth_token: str,
url_format: str,
extra_headers: Optional[Dict[str, str]] = None,
) -> TClient:
"""Initialize the Semantic Layer client.

Args:
server_host: the Semantic Layer API host
environment_id: your dbt environment ID
auth_token: the API auth token
url_format: the URL format string to construct the final URL with
extra_headers: extra headers to be sent with the request.
"""
pass
6 changes: 4 additions & 2 deletions dbtsl/api/adbc/client/sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Iterator, Optional
from typing import Dict, Iterator, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
Expand All @@ -17,6 +17,7 @@ def __init__(
environment_id: int,
auth_token: str,
url_format: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize the ADBC client.

Expand All @@ -28,8 +29,9 @@ def __init__(
into a full URL. If `None`, the default
`grpc+tls://{server_host}:443`
will be assumed.
extra_headers: extra headers to be sent with the request.
"""
super().__init__(server_host, environment_id, auth_token, url_format)
super().__init__(server_host, environment_id, auth_token, url_format, extra_headers=extra_headers)

@contextmanager
def session(self) -> Iterator[Self]:
Expand Down
4 changes: 4 additions & 0 deletions dbtsl/api/graphql/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__( # noqa: D107
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
extra_headers: Optional[Dict[str, str]] = None,
):
self.environment_id = environment_id
self.lazy = lazy
Expand All @@ -83,6 +84,7 @@ def __init__( # noqa: D107
headers = {
"authorization": f"bearer {auth_token}",
**self._extra_headers(),
**(extra_headers or {}),
}
transport = self._create_transport(url=server_url, headers=headers)
self._gql = Client(transport=transport, execute_timeout=self.timeout.execute_timeout)
Expand Down Expand Up @@ -163,6 +165,7 @@ def __call__(
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
extra_headers: Optional[Dict[str, str]] = None,
) -> TClient:
"""Initialize the Semantic Layer client.

Expand All @@ -173,5 +176,6 @@ def __call__(
url_format: the URL format string to construct the final URL with
timeout: `TimeoutOptions` or total timeout
lazy: lazy load large fields
extra_headers: extra headers to be sent with the request.
"""
pass
5 changes: 4 additions & 1 deletion dbtsl/api/graphql/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
extra_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the metadata client.

Expand All @@ -56,7 +57,9 @@ def __init__(
NOTE: If `timeout` is a `TimeoutOptions`, the `tls_close_timeout` will not be used, since
`requests` does not support TLS termination timeouts.
"""
super().__init__(server_host, environment_id, auth_token, url_format, timeout, lazy=lazy)
super().__init__(
server_host, environment_id, auth_token, url_format, timeout, lazy=lazy, extra_headers=extra_headers
)

@override
def _create_transport(self, url: str, headers: Dict[str, str]) -> RequestsHTTPTransport:
Expand Down
5 changes: 4 additions & 1 deletion dbtsl/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Any, Generic, Optional, TypeVar, Union
from typing import Any, Dict, Generic, Optional, TypeVar, Union

import dbtsl.env as env
from dbtsl.api.adbc.client.base import ADBCClientFactory, BaseADBCClient
Expand Down Expand Up @@ -43,6 +43,7 @@ def __init__(
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize the Semantic Layer client.

Expand All @@ -66,12 +67,14 @@ def __init__(
url_format=env.GRAPHQL_URL_FORMAT,
timeout=timeout,
lazy=lazy,
extra_headers=extra_headers,
)
self._adbc = adbc_factory(
server_host=host,
environment_id=environment_id,
auth_token=auth_token,
url_format=env.ADBC_URL_FORMAT,
extra_headers=extra_headers,
)

@property
Expand Down
5 changes: 4 additions & 1 deletion dbtsl/client/sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from typing import Dict, Iterator, Optional, Union

from typing_extensions import Self

Expand Down Expand Up @@ -28,6 +28,7 @@
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize the Semantic Layer client.

Expand All @@ -37,15 +38,17 @@
host: the Semantic Layer API host
timeout: `TimeoutOptions` or total timeout for the underlying GraphQL client.
lazy: if true, nested metadata queries will be need to be explicitly populated on-demand.
extra_headers: extra headers to be sent with the request.
"""
super().__init__(
environment_id=environment_id,
auth_token=auth_token,
host=host,
gql_factory=SyncGraphQLClient,

Check failure on line 47 in dbtsl/client/sync.py

View workflow job for this annotation

GitHub Actions / code-quality

Argument of type "type[SyncGraphQLClient]" cannot be assigned to parameter "gql_factory" of type "GraphQLClientFactory[SyncGraphQLClient]" in function "__init__"   Type "type[SyncGraphQLClient]" is not assignable to type "(server_host: str, environment_id: int, auth_token: str, url_format: str | None = None, timeout: TimeoutOptions | float | int | None = None, *, lazy: bool, extra_headers: Dict[str, str] | None = None) -> SyncGraphQLClient"     Missing keyword parameter "extra_headers" (reportArgumentType)
adbc_factory=SyncADBCClient,
timeout=timeout,
lazy=lazy,
extra_headers=extra_headers,
)

@contextmanager
Expand Down
Loading