Skip to content

Add support to pass in typed group by #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20250327-224851.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Add support to pass in typed group by
time: 2025-03-27T22:48:51.368818-05:00
13 changes: 13 additions & 0 deletions dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from dbtsl.api.shared.query_params import (
DimensionValuesQueryParameters,
GroupByParam,
GroupByType,
OrderByGroupBy,
OrderByMetric,
QueryParameters,
Expand Down Expand Up @@ -38,6 +40,17 @@ def _serialize_val(cls, val: Any) -> str:
d += ".descending(True)"
return d

if isinstance(val, GroupByParam):
g: str = ""
if val.type == GroupByType.DIMENSION:
g = f'Dimension("{val.name}")'
elif val.type == GroupByType.ENTITY:
g = f'Entity("{val.name}")'
if val.grain:
grain_str = val.grain.lower()
g += f'.grain("{grain_str}")'
return g

return json.dumps(val)

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,12 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) ->
return {
"savedQuery": None,
"metrics": [{"name": m} for m in strict_params.metrics] if strict_params.metrics is not None else None,
"groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None,
"groupBy": [
{"name": g} if isinstance(g, str) else {"name": g.name, "timeGranularity": g.grain}
for g in strict_params.group_by
]
if strict_params.group_by is not None
else None,
**shared_vars,
}

Expand Down
29 changes: 25 additions & 4 deletions dbtsl/api/shared/query_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, TypedDict, Union


class GroupByType(Enum):
DIMENSION = "dimension"
ENTITY = "entity"


@dataclass(frozen=True)
class GroupByParam:
"""Parameter for a group_by, i.e a dimension or an entity."""

name: str
type: GroupByType
grain: Optional[str]


@dataclass(frozen=True)
class OrderByMetric:
"""Spec for ordering by a metric."""
Expand Down Expand Up @@ -33,7 +48,7 @@ class QueryParameters(TypedDict, total=False):

saved_query: str
metrics: List[str]
group_by: List[str]
group_by: List[Union[GroupByParam, str]]
limit: int
order_by: List[Union[OrderBySpec, str]]
where: List[str]
Expand All @@ -45,7 +60,7 @@ class AdhocQueryParametersStrict:
"""The parameters of an adhoc query, strictly validated."""

metrics: Optional[List[str]]
group_by: Optional[List[str]]
group_by: Optional[List[Union[GroupByParam, str]]]
limit: Optional[int]
order_by: Optional[List[OrderBySpec]]
where: Optional[List[str]]
Expand All @@ -64,7 +79,9 @@ class SavedQueryQueryParametersStrict:


def validate_order_by(
known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str]
known_metrics: List[str],
known_group_bys: List[Union[str, GroupByParam]],
clause: Union[OrderBySpec, str],
) -> OrderBySpec:
"""Validate an order by clause like `-metric_name`."""
if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy):
Expand All @@ -77,7 +94,11 @@ def validate_order_by(
if clause in known_metrics:
return OrderByMetric(name=clause, descending=descending)

if clause in known_group_bys or clause == "metric_time":
normalized_known_group_bys = [
known_group_by.name if isinstance(known_group_by, GroupByParam) else known_group_by
for known_group_by in known_group_bys
]
if clause in normalized_known_group_bys or clause == "metric_time":
return OrderByGroupBy(name=clause, descending=descending, grain=None)

# TODO: make this error less strict when server supports order_by type inference.
Expand Down
19 changes: 18 additions & 1 deletion tests/api/adbc/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dbtsl.api.adbc.protocol import ADBCProtocol
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric
from dbtsl.api.shared.query_params import GroupByParam, GroupByType, OrderByGroupBy, OrderByMetric


def test_serialize_val_basic_values() -> None:
Expand Down Expand Up @@ -74,6 +74,23 @@ def test_get_query_sql_simple_query() -> None:
assert sql == expected


def test_get_query_sql_group_by_param() -> None:
sql = ADBCProtocol.get_query_sql(
params={
"metrics": ["a", "b"],
"group_by": [
GroupByParam(name="c", type=GroupByType.DIMENSION, grain="day"),
GroupByParam(name="d", type=GroupByType.ENTITY, grain="week"),
],
}
)
expected = (
"SELECT * FROM {{ semantic_layer.query("
'group_by=[Dimension("c").grain("day"),Entity("d").grain("week")],metrics=["a","b"],read_cache=True) }}'
)
assert sql == expected


def test_get_query_sql_dimension_values_query() -> None:
sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]})
expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}'
Expand Down
15 changes: 14 additions & 1 deletion tests/query_test_cases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

from dbtsl import OrderByGroupBy
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.api.shared.query_params import GroupByParam, GroupByType, QueryParameters

TEST_QUERIES: List[QueryParameters] = [
# ad hoc query, all parameters
Expand Down Expand Up @@ -38,4 +38,17 @@
{
"saved_query": "order_metrics",
},
# group by param object
{
"metrics": ["order_total"],
"group_by": [GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION)],
},
# multiple group by param objects
{
"metrics": ["order_total"],
"group_by": [
GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION),
GroupByParam(name="customer__customer_type", grain="week", type=GroupByType.DIMENSION),
],
},
]
28 changes: 28 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from dbtsl.api.graphql.util import normalize_query
from dbtsl.api.shared.query_params import (
AdhocQueryParametersStrict,
GroupByParam,
GroupByType,
OrderByGroupBy,
OrderByMetric,
QueryParameters,
Expand Down Expand Up @@ -289,3 +291,29 @@ def test_validate_query_params_no_query() -> None:
p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False}
with pytest.raises(ValueError):
validate_query_parameters(p)


def test_validate_query_params_group_by_param_dimension() -> None:
p: QueryParameters = {
"group_by": [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)],
"order_by": ["a"],
}
r = validate_query_parameters(p)
assert isinstance(r, AdhocQueryParametersStrict)
assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)]


def test_validate_query_params_group_by_param_entity() -> None:
p: QueryParameters = {"group_by": [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)], "order_by": ["a"]}
r = validate_query_parameters(p)
assert isinstance(r, AdhocQueryParametersStrict)
assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)]


def test_validate_missing_query_params_group_by_param() -> None:
p: QueryParameters = {
"group_by": [GroupByParam(name="b", grain="day", type=GroupByType.DIMENSION)],
"order_by": ["a"],
}
with pytest.raises(ValueError):
validate_query_parameters(p)
Loading