Skip to content

Commit a416e67

Browse files
authored
Add GroupByParam (#70)
1 parent c56ac1a commit a416e67

File tree

7 files changed

+107
-7
lines changed

7 files changed

+107
-7
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Features
2+
body: Add support for passing typed group-by parameters
3+
time: 2025-03-27T22:48:51.368818-05:00

dbtsl/api/adbc/protocol.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from dbtsl.api.shared.query_params import (
66
DimensionValuesQueryParameters,
7+
GroupByParam,
8+
GroupByType,
79
OrderByGroupBy,
810
OrderByMetric,
911
QueryParameters,
@@ -38,6 +40,17 @@ def _serialize_val(cls, val: Any) -> str:
3840
d += ".descending(True)"
3941
return d
4042

43+
if isinstance(val, GroupByParam):
44+
g: str = ""
45+
if val.type == GroupByType.DIMENSION:
46+
g = f'Dimension("{val.name}")'
47+
elif val.type == GroupByType.ENTITY:
48+
g = f'Entity("{val.name}")'
49+
if val.grain:
50+
grain_str = val.grain.lower()
51+
g += f'.grain("{grain_str}")'
52+
return g
53+
4154
return json.dumps(val)
4255

4356
@classmethod

dbtsl/api/graphql/protocol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,12 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) ->
220220
return {
221221
"savedQuery": None,
222222
"metrics": [{"name": m} for m in strict_params.metrics] if strict_params.metrics is not None else None,
223-
"groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None,
223+
"groupBy": [
224+
{"name": g} if isinstance(g, str) else {"name": g.name, "timeGranularity": g.grain}
225+
for g in strict_params.group_by
226+
]
227+
if strict_params.group_by is not None
228+
else None,
224229
**shared_vars,
225230
}
226231

dbtsl/api/shared/query_params.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
11
from dataclasses import dataclass
2+
from enum import Enum
23
from typing import List, Optional, TypedDict, Union
34

45

6+
class GroupByType(Enum):
7+
DIMENSION = "dimension"
8+
ENTITY = "entity"
9+
10+
11+
@dataclass(frozen=True)
12+
class GroupByParam:
13+
"""Parameter for a group_by, i.e. a dimension or an entity."""
14+
15+
name: str
16+
type: GroupByType
17+
grain: Optional[str]
18+
19+
520
@dataclass(frozen=True)
621
class OrderByMetric:
722
"""Spec for ordering by a metric."""
@@ -33,7 +48,7 @@ class QueryParameters(TypedDict, total=False):
3348

3449
saved_query: str
3550
metrics: List[str]
36-
group_by: List[str]
51+
group_by: List[Union[GroupByParam, str]]
3752
limit: int
3853
order_by: List[Union[OrderBySpec, str]]
3954
where: List[str]
@@ -45,7 +60,7 @@ class AdhocQueryParametersStrict:
4560
"""The parameters of an adhoc query, strictly validated."""
4661

4762
metrics: Optional[List[str]]
48-
group_by: Optional[List[str]]
63+
group_by: Optional[List[Union[GroupByParam, str]]]
4964
limit: Optional[int]
5065
order_by: Optional[List[OrderBySpec]]
5166
where: Optional[List[str]]
@@ -64,7 +79,9 @@ class SavedQueryQueryParametersStrict:
6479

6580

6681
def validate_order_by(
67-
known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str]
82+
known_metrics: List[str],
83+
known_group_bys: List[Union[str, GroupByParam]],
84+
clause: Union[OrderBySpec, str],
6885
) -> OrderBySpec:
6986
"""Validate an order by clause like `-metric_name`."""
7087
if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy):
@@ -77,7 +94,11 @@ def validate_order_by(
7794
if clause in known_metrics:
7895
return OrderByMetric(name=clause, descending=descending)
7996

80-
if clause in known_group_bys or clause == "metric_time":
97+
normalized_known_group_bys = [
98+
known_group_by.name if isinstance(known_group_by, GroupByParam) else known_group_by
99+
for known_group_by in known_group_bys
100+
]
101+
if clause in normalized_known_group_bys or clause == "metric_time":
81102
return OrderByGroupBy(name=clause, descending=descending, grain=None)
82103

83104
# TODO: make this error less strict when the server supports order_by type inference.

tests/api/adbc/test_protocol.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dbtsl.api.adbc.protocol import ADBCProtocol
2-
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric
2+
from dbtsl.api.shared.query_params import GroupByParam, GroupByType, OrderByGroupBy, OrderByMetric
33

44

55
def test_serialize_val_basic_values() -> None:
@@ -74,6 +74,23 @@ def test_get_query_sql_simple_query() -> None:
7474
assert sql == expected
7575

7676

77+
def test_get_query_sql_group_by_param() -> None:
78+
sql = ADBCProtocol.get_query_sql(
79+
params={
80+
"metrics": ["a", "b"],
81+
"group_by": [
82+
GroupByParam(name="c", type=GroupByType.DIMENSION, grain="day"),
83+
GroupByParam(name="d", type=GroupByType.ENTITY, grain="week"),
84+
],
85+
}
86+
)
87+
expected = (
88+
"SELECT * FROM {{ semantic_layer.query("
89+
'group_by=[Dimension("c").grain("day"),Entity("d").grain("week")],metrics=["a","b"],read_cache=True) }}'
90+
)
91+
assert sql == expected
92+
93+
7794
def test_get_query_sql_dimension_values_query() -> None:
7895
sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]})
7996
expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}'

tests/query_test_cases.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List
22

33
from dbtsl import OrderByGroupBy
4-
from dbtsl.api.shared.query_params import QueryParameters
4+
from dbtsl.api.shared.query_params import GroupByParam, GroupByType, QueryParameters
55

66
TEST_QUERIES: List[QueryParameters] = [
77
# ad hoc query, all parameters
@@ -38,4 +38,17 @@
3838
{
3939
"saved_query": "order_metrics",
4040
},
41+
# group by param object
42+
{
43+
"metrics": ["order_total"],
44+
"group_by": [GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION)],
45+
},
46+
# multiple group by param objects
47+
{
48+
"metrics": ["order_total"],
49+
"group_by": [
50+
GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION),
51+
GroupByParam(name="customer__customer_type", grain="week", type=GroupByType.DIMENSION),
52+
],
53+
},
4154
]

tests/test_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from dbtsl.api.graphql.util import normalize_query
1414
from dbtsl.api.shared.query_params import (
1515
AdhocQueryParametersStrict,
16+
GroupByParam,
17+
GroupByType,
1618
OrderByGroupBy,
1719
OrderByMetric,
1820
QueryParameters,
@@ -289,3 +291,29 @@ def test_validate_query_params_no_query() -> None:
289291
p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False}
290292
with pytest.raises(ValueError):
291293
validate_query_parameters(p)
294+
295+
296+
def test_validate_query_params_group_by_param_dimension() -> None:
297+
p: QueryParameters = {
298+
"group_by": [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)],
299+
"order_by": ["a"],
300+
}
301+
r = validate_query_parameters(p)
302+
assert isinstance(r, AdhocQueryParametersStrict)
303+
assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)]
304+
305+
306+
def test_validate_query_params_group_by_param_entity() -> None:
307+
p: QueryParameters = {"group_by": [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)], "order_by": ["a"]}
308+
r = validate_query_parameters(p)
309+
assert isinstance(r, AdhocQueryParametersStrict)
310+
assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)]
311+
312+
313+
def test_validate_missing_query_params_group_by_param() -> None:
314+
p: QueryParameters = {
315+
"group_by": [GroupByParam(name="b", grain="day", type=GroupByType.DIMENSION)],
316+
"order_by": ["a"],
317+
}
318+
with pytest.raises(ValueError):
319+
validate_query_parameters(p)

0 commit comments

Comments
 (0)