Skip to content

Commit 6770a8d

Browse files
committed
Add GroupByParam
1 parent b239832 commit 6770a8d

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Features
2+
body: Add support to pass in typed group by
3+
time: 2025-03-27T22:48:51.368818-05:00

dbtsl/api/shared/query_params.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
from typing import List, Optional, TypedDict, Union
33

44

5+
@dataclass(frozen=True)
6+
class GroupByParam:
7+
"""Parameter for a group_by, i.e a dimension or an entity."""
8+
9+
name: str
10+
grain: Optional[str]
11+
12+
513
@dataclass(frozen=True)
614
class OrderByMetric:
715
"""Spec for ordering by a metric."""
@@ -33,7 +41,7 @@ class QueryParameters(TypedDict, total=False):
3341

3442
saved_query: str
3543
metrics: List[str]
36-
group_by: List[str]
44+
group_by: List[Union[GroupByParam, str]]
3745
limit: int
3846
order_by: List[Union[OrderBySpec, str]]
3947
where: List[str]
@@ -45,7 +53,7 @@ class AdhocQueryParametersStrict:
4553
"""The parameters of an adhoc query, strictly validated."""
4654

4755
metrics: Optional[List[str]]
48-
group_by: Optional[List[str]]
56+
group_by: Optional[List[Union[GroupByParam, str]]]
4957
limit: Optional[int]
5058
order_by: Optional[List[OrderBySpec]]
5159
where: Optional[List[str]]
@@ -64,7 +72,9 @@ class SavedQueryQueryParametersStrict:
6472

6573

6674
def validate_order_by(
67-
known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str]
75+
known_metrics: List[str],
76+
known_group_bys: List[Union[str, GroupByParam]],
77+
clause: Union[OrderBySpec, str],
6878
) -> OrderBySpec:
6979
"""Validate an order by clause like `-metric_name`."""
7080
if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy):
@@ -77,7 +87,11 @@ def validate_order_by(
7787
if clause in known_metrics:
7888
return OrderByMetric(name=clause, descending=descending)
7989

80-
if clause in known_group_bys or clause == "metric_time":
90+
normalized_known_group_bys = [
91+
known_group_by.name if isinstance(known_group_by, GroupByParam) else known_group_by
92+
for known_group_by in known_group_bys
93+
]
94+
if clause in normalized_known_group_bys or clause == "metric_time":
8195
return OrderByGroupBy(name=clause, descending=descending, grain=None)
8296

8397
# TODO: make this error less strict when server supports order_by type inference.

tests/api/adbc/test_protocol.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dbtsl.api.adbc.protocol import ADBCProtocol
2-
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric
2+
from dbtsl.api.shared.query_params import GroupByParam, OrderByGroupBy, OrderByMetric
33

44

55
def test_serialize_val_basic_values() -> None:
@@ -74,6 +74,23 @@ def test_get_query_sql_simple_query() -> None:
7474
assert sql == expected
7575

7676

77+
def test_get_query_sql_group_by_param() -> None:
78+
sql = ADBCProtocol.get_query_sql(
79+
params={
80+
"metrics": ["a", "b"],
81+
"group_by": [
82+
GroupByParam(name="c", grain="day"),
83+
GroupByParam(name="d", grain="week"),
84+
],
85+
}
86+
)
87+
expected = (
88+
'SELECT * FROM {{ semantic_layer.query(metrics=["a","b"],'
89+
'group_by=[Dimension("c").grain("day"),Dimension("d").grain("week")],read_cache=True) }}'
90+
)
91+
assert sql == expected
92+
93+
7794
def test_get_query_sql_dimension_values_query() -> None:
7895
sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]})
7996
expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}'

tests/query_test_cases.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import List
22

33
from dbtsl import OrderByGroupBy
4-
from dbtsl.api.shared.query_params import QueryParameters
4+
from dbtsl.api.shared.query_params import GroupByParam, QueryParameters
55

66
TEST_QUERIES: List[QueryParameters] = [
77
# ad hoc query, all parameters
@@ -38,4 +38,14 @@
3838
{
3939
"saved_query": "order_metrics",
4040
},
41+
# group by param object
42+
{"metrics": ["order_total"], "group_by": [GroupByParam(name="order_date", grain="month")]},
43+
# multiple group by param objects
44+
{
45+
"metrics": ["order_total"],
46+
"group_by": [
47+
GroupByParam(name="order_date", grain="month"),
48+
GroupByParam(name="fulfillment_date", grain="week"),
49+
],
50+
},
4151
]

0 commit comments

Comments
 (0)