Skip to content

Commit b32a290

Browse files
committed
Add support to pass in typed metrics & group by
1 parent 7518f75 commit b32a290

File tree

4 files changed

+82
-16
lines changed

4 files changed

+82
-16
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Features
2+
body: Add support to pass in typed metrics and group by for query function
3+
time: 2025-03-20T15:59:41.744532-05:00

dbtsl/api/graphql/protocol.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dbtsl.api.graphql.util import render_query
88
from dbtsl.api.shared.query_params import (
99
AdhocQueryParametersStrict,
10+
OrderByGroupBy,
1011
OrderByMetric,
1112
QueryParameters,
1213
validate_query_parameters,
@@ -201,17 +202,20 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) ->
201202
"""Get the GraphQL request variables for a given set of query parameters."""
202203
strict_params = validate_query_parameters(params) # type: ignore
203204

205+
order_by_vars = []
206+
for clause in strict_params.order_by or []:
207+
if isinstance(clause, OrderByMetric):
208+
order_by_vars.append({"metric": {"name": clause.name}, "descending": clause.descending})
209+
else:
210+
assert isinstance(clause, OrderByGroupBy)
211+
order_by_vars.append(
212+
{"groupBy": {"name": clause.name, "timeGranularity": clause.grain}, "descending": clause.descending}
213+
)
214+
204215
shared_vars = {
205216
"environmentId": environment_id,
206217
"where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None,
207-
"orderBy": [
208-
{"metric": {"name": clause.name}, "descending": clause.descending}
209-
if isinstance(clause, OrderByMetric)
210-
else {"groupBy": {"name": clause.name, "timeGranularity": clause.grain}, "descending": clause.descending}
211-
for clause in strict_params.order_by
212-
]
213-
if strict_params.order_by is not None
214-
else None,
218+
"orderBy": order_by_vars if strict_params.order_by is not None else None,
215219
"limit": strict_params.limit,
216220
"readCache": strict_params.read_cache,
217221
}

dbtsl/api/shared/query_params.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ class OrderByMetric:
1010
descending: bool = False
1111

1212

13+
@dataclass(frozen=True)
14+
class GroupByParam:
15+
"""Spec for a group_by, i.e a dimension or an entity.
16+
17+
Not specifying a grain will defer the grain choice to the server.
18+
"""
19+
20+
name: str
21+
grain: Optional[str]
22+
1323
@dataclass(frozen=True)
1424
class OrderByGroupBy:
1525
"""Spec for ordering by a group_by, i.e a dimension or an entity.
@@ -21,7 +31,6 @@ class OrderByGroupBy:
2131
grain: Optional[str]
2232
descending: bool = False
2333

24-
2534
OrderBySpec = Union[OrderByMetric, OrderByGroupBy]
2635

2736

@@ -33,7 +42,7 @@ class QueryParameters(TypedDict, total=False):
3342

3443
saved_query: str
3544
metrics: List[str]
36-
group_by: List[str]
45+
group_by: List[Union[GroupByParam, str]]
3746
limit: int
3847
order_by: List[Union[OrderBySpec, str]]
3948
where: List[str]
@@ -44,8 +53,8 @@ class QueryParameters(TypedDict, total=False):
4453
class AdhocQueryParametersStrict:
4554
"""The parameters of an adhoc query, strictly validated."""
4655

47-
metrics: Optional[List[str]]
48-
group_by: Optional[List[str]]
56+
metrics: Optional[List[Union[str, OrderByMetric]]]
57+
group_by: Optional[List[Union[str, OrderByGroupBy]]]
4958
limit: Optional[int]
5059
order_by: Optional[List[OrderBySpec]]
5160
where: Optional[List[str]]
@@ -64,7 +73,9 @@ class SavedQueryQueryParametersStrict:
6473

6574

6675
def validate_order_by(
67-
known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str]
76+
known_metrics: List[Union[str, OrderByMetric]],
77+
known_group_bys: List[Union[str, OrderByGroupBy]],
78+
clause: Union[OrderBySpec, str],
6879
) -> OrderBySpec:
6980
"""Validate an order by clause like `-metric_name`."""
7081
if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy):
@@ -74,10 +85,23 @@ def validate_order_by(
7485
if descending or clause.startswith("+"):
7586
clause = clause[1:]
7687

77-
if clause in known_metrics:
88+
normalized_known_metrics: list[str] = []
89+
for known_metric in known_metrics:
90+
if isinstance(known_metric, OrderByMetric):
91+
normalized_known_metrics.append(known_metric.name)
92+
else:
93+
normalized_known_metrics.append(known_metric)
94+
95+
if clause in normalized_known_metrics:
7896
return OrderByMetric(name=clause, descending=descending)
7997

80-
if clause in known_group_bys or clause == "metric_time":
98+
normalized_known_group_bys: list[str] = []
99+
for known_group_by in known_group_bys:
100+
if isinstance(known_group_by, OrderByGroupBy):
101+
normalized_known_group_bys.append(known_group_by.name)
102+
else:
103+
normalized_known_group_bys.append(known_group_by)
104+
if clause in normalized_known_group_bys or clause == "metric_time":
81105
return OrderByGroupBy(name=clause, descending=descending, grain=None)
82106

83107
# TODO: make this error less strict when server supports order_by type inference.

tests/query_test_cases.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List
22

33
from dbtsl import OrderByGroupBy
4-
from dbtsl.api.shared.query_params import QueryParameters
4+
from dbtsl.api.shared.query_params import OrderByMetric, QueryParameters
55

66
TEST_QUERIES: List[QueryParameters] = [
77
# ad hoc query, all parameters
@@ -38,4 +38,39 @@
3838
{
3939
"saved_query": "order_metrics",
4040
},
41+
# typed metric & group by
42+
{
43+
"metrics": [OrderByMetric(name="order_total", descending=False)],
44+
"group_by": [OrderByGroupBy(name="customer__customer_type", grain="day", descending=False)],
45+
},
46+
# multiple typed metrics
47+
{
48+
"metrics": [
49+
OrderByMetric(name="order_total", descending=True),
50+
OrderByMetric(name="order_count", descending=False),
51+
],
52+
},
53+
# multiple typed group by
54+
{
55+
"metrics": [OrderByMetric(name="order_total", descending=False)],
56+
"group_by": [
57+
OrderByGroupBy(name="customer__customer_type", grain="day", descending=True),
58+
OrderByGroupBy(name="order__status", grain="month", descending=False),
59+
],
60+
},
61+
# typed metrics with different grains
62+
{
63+
"metrics": [OrderByMetric(name="order_total", descending=True)],
64+
"group_by": [
65+
OrderByGroupBy(name="order__created_at", grain="day", descending=True),
66+
OrderByGroupBy(name="order__created_at", grain="month", descending=False),
67+
],
68+
},
69+
# typed metrics with where clause and limit
70+
{
71+
"metrics": [OrderByMetric(name="order_total", descending=True)],
72+
"group_by": [OrderByGroupBy(name="customer__customer_type", grain="day", descending=True)],
73+
"where": ["order_total > 1000"],
74+
"limit": 10,
75+
},
4176
]

0 commit comments

Comments
 (0)