Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20250414-170939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Add `lazy` parameter to clients which allows lazy loading of certain model fields.
time: 2025-04-14T17:09:39.64588+02:00
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ arrow_table = client.query(...)
polars_df = pl.from_arrow(arrow_table)
```

### Lazy loading

By default, the SDK eagerly requests lists of nested objects. For example, in the list of `Metric` returned by `client.metrics()`, each metric will contain the list of its dimensions, entities and measures. This is convenient in most cases, but it can make the returned data very large if your project is large, which can slow things down.

It is possible to set the client to `lazy=True`, which will make it skip populating nested object lists unless you explicitly ask for them on a per-model basis. Check our [lazy loading example](./examples/list_metrics_lazy_sync.py) to learn more.

### More examples

Check out our [usage examples](./examples/) to learn more.
Expand Down
11 changes: 8 additions & 3 deletions dbtsl/api/graphql/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
):
"""Initialize the metadata client.

Expand All @@ -60,12 +62,13 @@ def __init__(
into a full URL. If `None`, the default `https://{server_host}/api/graphql`
will be assumed.
timeout: TimeoutOptions or total timeout (in seconds) for all GraphQL requests.
lazy: Whether to lazy load large subfields

NOTE: If `timeout` is a `TimeoutOptions`, the `connect_timeout` will not be used, due to
limitations of `gql`'s `aiohttp` transport.
See: https://github.com/graphql-python/gql/blob/b066e8944b0da0a4bbac6c31f43e5c3c7772cd51/gql/transport/aiohttp.py#L110
"""
super().__init__(server_host, environment_id, auth_token, url_format, timeout)
super().__init__(server_host, environment_id, auth_token, url_format, timeout, lazy=lazy)

@override
def _create_transport(self, url: str, headers: Dict[str, str]) -> AIOHTTPTransport:
Expand Down Expand Up @@ -97,7 +100,7 @@ async def session(self) -> AsyncIterator[Self]:

async def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables: TVariables) -> TResponse:
"""Run a `ProtocolOperation`."""
raw_query = op.get_request_text()
raw_query = op.get_request_text(lazy=self.lazy)
variables = op.get_request_variables(environment_id=self.environment_id, variables=raw_variables)
gql_query = gql(raw_query)

Expand All @@ -114,7 +117,9 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables
except Exception as err:
raise self._refine_err(err)

return op.parse_response(res)
resp = op.parse_response(res)
self._attach_self_to_parsed_response(resp)
return resp

async def _poll_until_complete(
self,
Expand Down
6 changes: 4 additions & 2 deletions dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ from typing_extensions import AsyncIterator, Unpack, overload

from dbtsl.api.shared.query_params import GroupByParam, OrderByGroupBy, OrderByMetric, QueryParameters
from dbtsl.models import (
AsyncMetric,
Dimension,
Entity,
Measure,
Metric,
SavedQuery,
)
from dbtsl.timeout import TimeoutOptions
Expand All @@ -24,11 +24,13 @@ class AsyncGraphQLClient:
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
) -> None: ...
def session(self) -> AbstractAsyncContextManager[AsyncIterator[Self]]: ...
@property
def has_session(self) -> bool: ...
async def metrics(self) -> List[Metric]:
async def metrics(self) -> List[AsyncMetric]:
"""Get a list of all available metrics."""
...

Expand Down
19 changes: 19 additions & 0 deletions dbtsl/api/graphql/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from dbtsl.backoff import ExponentialBackoff
from dbtsl.error import AuthError
from dbtsl.models.base import GraphQLFragmentMixin
from dbtsl.timeout import TimeoutOptions

TTransport = TypeVar("TTransport", Transport, AsyncTransport)
Expand Down Expand Up @@ -59,8 +60,11 @@ def __init__( # noqa: D107
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
):
self.environment_id = environment_id
self.lazy = lazy

url_format = url_format or self.DEFAULT_URL_FORMAT
server_url = url_format.format(server_host=server_host)
Expand Down Expand Up @@ -101,6 +105,18 @@ def _refine_err(self, err: Exception) -> Exception:

return err

def _attach_self_to_parsed_response(self, resp: object) -> None:
    """Attach this client to the model objects found in a parsed response.

    A `GraphQLFragmentMixin` instance gets this client stored in its
    `_client_unchecked` attribute. A list is walked recursively so that every
    element (and every element of nested lists) is handled the same way.
    Any other kind of value is left untouched.

    Args:
        resp: the value returned by parsing a GraphQL response.
    """
    # NOTE: `_client_unchecked` is assigned directly instead of going through a
    # public property on purpose: end-users should not be aware of it. Treat it
    # as public within this module, but private to end users.
    if not isinstance(resp, GraphQLFragmentMixin):
        # Not a model object: the only other thing we descend into is a list.
        if isinstance(resp, list):
            for item in resp:  # pyright: ignore[reportUnknownVariableType]
                self._attach_self_to_parsed_response(item)  # pyright: ignore[reportUnknownArgumentType]
        return

    resp._client_unchecked = self  # type: ignore

@property
def _gql_session(self) -> TSession:
"""Safe accessor to `_gql_session_unsafe`.
Expand Down Expand Up @@ -145,6 +161,8 @@ def __call__(
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
) -> TClient:
"""Initialize the Semantic Layer client.

Expand All @@ -154,5 +172,6 @@ def __call__(
auth_token: the API auth token
url_format: the URL format string to construct the final URL with
timeout: `TimeoutOptions` or total timeout
lazy: lazy load large fields
"""
pass
11 changes: 8 additions & 3 deletions dbtsl/api/graphql/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
):
"""Initialize the metadata client.

Expand All @@ -49,11 +51,12 @@ def __init__(
into a full URL. If `None`, the default `https://{server_host}/api/graphql`
will be assumed.
timeout: TimeoutOptions or total timeout (in seconds) for all GraphQL requests.
lazy: Whether to lazy load large subfields

NOTE: If `timeout` is a `TimeoutOptions`, the `tls_close_timeout` will not be used, since
`requests` does not support TLS termination timeouts.
"""
super().__init__(server_host, environment_id, auth_token, url_format, timeout)
super().__init__(server_host, environment_id, auth_token, url_format, timeout, lazy=lazy)

@override
def _create_transport(self, url: str, headers: Dict[str, str]) -> RequestsHTTPTransport:
Expand Down Expand Up @@ -85,7 +88,7 @@ def session(self) -> Iterator[Self]:

def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables: TVariables) -> TResponse:
"""Run a `ProtocolOperation`."""
raw_query = op.get_request_text()
raw_query = op.get_request_text(lazy=self.lazy)
variables = op.get_request_variables(environment_id=self.environment_id, variables=raw_variables)
gql_query = gql(raw_query)

Expand All @@ -98,7 +101,9 @@ def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables: TVar
except Exception as err:
raise self._refine_err(err)

return op.parse_response(res)
resp = op.parse_response(res)
self._attach_self_to_parsed_response(resp)
return resp

def _poll_until_complete(
self,
Expand Down
6 changes: 4 additions & 2 deletions dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ from dbtsl.models import (
Dimension,
Entity,
Measure,
Metric,
SavedQuery,
SyncMetric,
)
from dbtsl.timeout import TimeoutOptions

Expand All @@ -24,11 +24,13 @@ class SyncGraphQLClient:
auth_token: str,
url_format: Optional[str] = None,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool,
) -> None: ...
def session(self) -> AbstractContextManager[Iterator[Self]]: ...
@property
def has_session(self) -> bool: ...
def metrics(self) -> List[Metric]:
def metrics(self) -> List[SyncMetric]:
"""Get a list of all available metrics."""
...

Expand Down
30 changes: 15 additions & 15 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ProtocolOperation(Generic[TVariables, TResponse], ABC):
"""Base class for GraphQL API operations."""

@abstractmethod
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
"""Get the GraphQL request text."""
raise NotImplementedError()

Expand All @@ -71,15 +71,15 @@ class ListMetricsOperation(ProtocolOperation[EmptyVariables, List[Metric]]):
"""List all available metrics in available in the Semantic Layer."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getMetrics($environmentId: BigInt!) {
metrics(environmentId: $environmentId) {
...&fragment
}
}
"""
return render_query(query, Metric.gql_fragments())
return render_query(query, Metric.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: EmptyVariables) -> Dict[str, Any]:
Expand All @@ -100,15 +100,15 @@ class ListDimensionsOperation(ProtocolOperation[ListEntitiesOperationVariables,
"""List all dimensions for a given set of metrics."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getDimensions($environmentId: BigInt!, $metrics: [MetricInput!]!) {
dimensions(environmentId: $environmentId, metrics: $metrics) {
...&fragment
}
}
"""
return render_query(query, Dimension.gql_fragments())
return render_query(query, Dimension.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: ListEntitiesOperationVariables) -> Dict[str, Any]:
Expand All @@ -126,15 +126,15 @@ class ListMeasuresOperation(ProtocolOperation[ListEntitiesOperationVariables, Li
"""List all measures for a given set of metrics."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getMeasures($environmentId: BigInt!, $metrics: [MetricInput!]!) {
measures(environmentId: $environmentId, metrics: $metrics) {
...&fragment
}
}
"""
return render_query(query, Measure.gql_fragments())
return render_query(query, Measure.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: ListEntitiesOperationVariables) -> Dict[str, Any]:
Expand All @@ -152,15 +152,15 @@ class ListEntitiesOperation(ProtocolOperation[ListEntitiesOperationVariables, Li
"""List all entities for a given set of metrics."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getEntities($environmentId: BigInt!, $metrics: [MetricInput!]!) {
entities(environmentId: $environmentId, metrics: $metrics) {
...&fragment
}
}
"""
return render_query(query, Entity.gql_fragments())
return render_query(query, Entity.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: ListEntitiesOperationVariables) -> Dict[str, Any]:
Expand All @@ -178,15 +178,15 @@ class ListSavedQueriesOperation(ProtocolOperation[EmptyVariables, List[SavedQuer
"""List all saved queries."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getSavedQueries($environmentId: BigInt!) {
savedQueries(environmentId: $environmentId) {
...&fragment
}
}
"""
return render_query(query, SavedQuery.gql_fragments())
return render_query(query, SavedQuery.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: EmptyVariables) -> Dict[str, Any]:
Expand Down Expand Up @@ -242,7 +242,7 @@ class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]):
"""Create a query that will be processed asynchronously."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
mutation createQuery(
$environmentId: BigInt!,
Expand Down Expand Up @@ -290,7 +290,7 @@ class GetQueryResultOperation(ProtocolOperation[GetQueryResultVariables, QueryRe
"""Get the results of a query that was already created."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
query getQueryResults(
$environmentId: BigInt!,
Expand All @@ -302,7 +302,7 @@ def get_request_text(self) -> str:
}
}
"""
return render_query(query, QueryResult.gql_fragments())
return render_query(query, QueryResult.gql_fragments(lazy=lazy))

@override
def get_request_variables(self, environment_id: int, variables: GetQueryResultVariables) -> Dict[str, Any]:
Expand All @@ -321,7 +321,7 @@ class CompileSqlOperation(ProtocolOperation[QueryParameters, str]):
"""Get the compiled SQL that would be sent to the warehouse by a query."""

@override
def get_request_text(self) -> str:
def get_request_text(self, *, lazy: bool) -> str:
query = """
mutation compileSql(
$environmentId: BigInt!,
Expand Down
4 changes: 4 additions & 0 deletions dbtsl/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(
auth_token: str,
host: str,
timeout: Optional[Union[TimeoutOptions, float, int]] = None,
*,
lazy: bool = False,
) -> None:
"""Initialize the Semantic Layer client.

Expand All @@ -34,6 +36,7 @@ def __init__(
auth_token: the API auth token
host: the Semantic Layer API host
timeout: `TimeoutOptions` or total timeout for the underlying GraphQL client.
lazy: if true, nested metadata queries will need to be explicitly populated on-demand.
"""
super().__init__(
environment_id=environment_id,
Expand All @@ -42,6 +45,7 @@ def __init__(
gql_factory=AsyncGraphQLClient,
adbc_factory=AsyncADBCClient,
timeout=timeout,
lazy=lazy,
)

@asynccontextmanager
Expand Down
Loading
Loading