From 7455034461628d2ed566c03ad8143ccc7b244df3 Mon Sep 17 00:00:00 2001 From: Devon Fulcher <24593113+DevonFulcher@users.noreply.github.com> Date: Mon, 31 Mar 2025 19:46:31 -0500 Subject: [PATCH] Add GroupByParam --- .../unreleased/Features-20250327-224851.yaml | 3 ++ dbtsl/api/adbc/protocol.py | 13 +++++++++ dbtsl/api/graphql/protocol.py | 7 ++++- dbtsl/api/shared/query_params.py | 29 ++++++++++++++++--- tests/api/adbc/test_protocol.py | 19 +++++++++++- tests/query_test_cases.py | 15 +++++++++- tests/test_models.py | 28 ++++++++++++++++++ 7 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 .changes/unreleased/Features-20250327-224851.yaml diff --git a/.changes/unreleased/Features-20250327-224851.yaml b/.changes/unreleased/Features-20250327-224851.yaml new file mode 100644 index 0000000..5419e16 --- /dev/null +++ b/.changes/unreleased/Features-20250327-224851.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Add support to pass in typed group by +time: 2025-03-27T22:48:51.368818-05:00 diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index b1a58b9..6b9d88d 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -4,6 +4,8 @@ from dbtsl.api.shared.query_params import ( DimensionValuesQueryParameters, + GroupByParam, + GroupByType, OrderByGroupBy, OrderByMetric, QueryParameters, @@ -38,6 +40,17 @@ def _serialize_val(cls, val: Any) -> str: d += ".descending(True)" return d + if isinstance(val, GroupByParam): + g: str = "" + if val.type == GroupByType.DIMENSION: + g = f'Dimension("{val.name}")' + elif val.type == GroupByType.ENTITY: + g = f'Entity("{val.name}")' + if val.grain: + grain_str = val.grain.lower() + g += f'.grain("{grain_str}")' + return g + return json.dumps(val) @classmethod diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index 4115f99..7954a1c 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -220,7 +220,12 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) -> return { "savedQuery": None, "metrics": [{"name": m} for m in strict_params.metrics] if strict_params.metrics is not None else None, - "groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None, + "groupBy": [ + {"name": g} if isinstance(g, str) else {"name": g.name, "timeGranularity": g.grain} + for g in strict_params.group_by + ] + if strict_params.group_by is not None + else None, **shared_vars, } diff --git a/dbtsl/api/shared/query_params.py b/dbtsl/api/shared/query_params.py index 3537715..1bb1fbc 100644 --- a/dbtsl/api/shared/query_params.py +++ b/dbtsl/api/shared/query_params.py @@ -1,7 +1,22 @@ from dataclasses import dataclass +from enum import Enum from typing import List, Optional, TypedDict, Union +class GroupByType(Enum): + DIMENSION = "dimension" + ENTITY = "entity" + + +@dataclass(frozen=True) +class GroupByParam: + """Parameter for a group_by, i.e a dimension or an entity.""" + + name: str + type: GroupByType + grain: Optional[str] + + @dataclass(frozen=True) class OrderByMetric: """Spec for ordering by a metric.""" @@ -33,7 +48,7 @@ class QueryParameters(TypedDict, total=False): saved_query: str metrics: List[str] - group_by: List[str] + group_by: List[Union[GroupByParam, str]] limit: int order_by: List[Union[OrderBySpec, str]] where: List[str] @@ -45,7 +60,7 @@ class AdhocQueryParametersStrict: """The parameters of an adhoc query, strictly validated.""" metrics: Optional[List[str]] - group_by: Optional[List[str]] + group_by: Optional[List[Union[GroupByParam, str]]] limit: Optional[int] order_by: Optional[List[OrderBySpec]] where: Optional[List[str]] @@ -64,7 +79,9 @@ class SavedQueryQueryParametersStrict: def validate_order_by( - known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str] + known_metrics: List[str], + known_group_bys: List[Union[str, GroupByParam]], + clause: Union[OrderBySpec, str], ) -> OrderBySpec: """Validate an order by clause like `-metric_name`.""" if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy): @@ -77,7 +94,11 @@ def validate_order_by( if clause in known_metrics: return OrderByMetric(name=clause, descending=descending) - if clause in known_group_bys or clause == "metric_time": + normalized_known_group_bys = [ + known_group_by.name if isinstance(known_group_by, GroupByParam) else known_group_by + for known_group_by in known_group_bys + ] + if clause in normalized_known_group_bys or clause == "metric_time": return OrderByGroupBy(name=clause, descending=descending, grain=None) # TODO: make this error less strict when server supports order_by type inference. diff --git a/tests/api/adbc/test_protocol.py b/tests/api/adbc/test_protocol.py index b3a40e0..99d9910 100644 --- a/tests/api/adbc/test_protocol.py +++ b/tests/api/adbc/test_protocol.py @@ -1,5 +1,5 @@ from dbtsl.api.adbc.protocol import ADBCProtocol -from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric +from dbtsl.api.shared.query_params import GroupByParam, GroupByType, OrderByGroupBy, OrderByMetric def test_serialize_val_basic_values() -> None: @@ -74,6 +74,23 @@ def test_get_query_sql_simple_query() -> None: assert sql == expected +def test_get_query_sql_group_by_param() -> None: + sql = ADBCProtocol.get_query_sql( + params={ + "metrics": ["a", "b"], + "group_by": [ + GroupByParam(name="c", type=GroupByType.DIMENSION, grain="day"), + GroupByParam(name="d", type=GroupByType.ENTITY, grain="week"), + ], + } + ) + expected = ( + "SELECT * FROM {{ semantic_layer.query(" + 'group_by=[Dimension("c").grain("day"),Entity("d").grain("week")],metrics=["a","b"],read_cache=True) }}' + ) + assert sql == expected + + def test_get_query_sql_dimension_values_query() -> None: sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]}) expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}' diff --git a/tests/query_test_cases.py b/tests/query_test_cases.py index 7480937..de0a578 100644 --- a/tests/query_test_cases.py +++ b/tests/query_test_cases.py @@ -1,7 +1,7 @@ from typing import List from dbtsl import OrderByGroupBy -from dbtsl.api.shared.query_params import QueryParameters +from dbtsl.api.shared.query_params import GroupByParam, GroupByType, QueryParameters TEST_QUERIES: List[QueryParameters] = [ # ad hoc query, all parameters @@ -38,4 +38,17 @@ { "saved_query": "order_metrics", }, + # group by param object + { + "metrics": ["order_total"], + "group_by": [GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION)], + }, + # multiple group by param objects + { + "metrics": ["order_total"], + "group_by": [ + GroupByParam(name="customer__customer_type", grain="month", type=GroupByType.DIMENSION), + GroupByParam(name="customer__customer_type", grain="week", type=GroupByType.DIMENSION), + ], + }, ] diff --git a/tests/test_models.py b/tests/test_models.py index 6d73b58..8adf211 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,8 @@ from dbtsl.api.graphql.util import normalize_query from dbtsl.api.shared.query_params import ( AdhocQueryParametersStrict, + GroupByParam, + GroupByType, OrderByGroupBy, OrderByMetric, QueryParameters, @@ -289,3 +291,29 @@ def test_validate_query_params_no_query() -> None: p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False} with pytest.raises(ValueError): validate_query_parameters(p) + + +def test_validate_query_params_group_by_param_dimension() -> None: + p: QueryParameters = { + "group_by": [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)], + "order_by": ["a"], + } + r = validate_query_parameters(p) + assert isinstance(r, AdhocQueryParametersStrict) + assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.DIMENSION)] + + +def test_validate_query_params_group_by_param_entity() -> None: + p: QueryParameters = {"group_by": [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)], "order_by": ["a"]} + r = validate_query_parameters(p) + assert isinstance(r, AdhocQueryParametersStrict) + assert r.group_by == [GroupByParam(name="a", grain="day", type=GroupByType.ENTITY)] + + +def test_validate_missing_query_params_group_by_param() -> None: + p: QueryParameters = { + "group_by": [GroupByParam(name="b", grain="day", type=GroupByType.DIMENSION)], + "order_by": ["a"], + } + with pytest.raises(ValueError): + validate_query_parameters(p)