Skip to content

Commit e412122

Browse files
committed
Fix validate_order_by
1 parent b32a290 commit e412122

File tree

2 files changed

+14
-24
lines changed

2 files changed

+14
-24
lines changed

dbtsl/api/graphql/protocol.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,17 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) ->
202202
"""Get the GraphQL request variables for a given set of query parameters."""
203203
strict_params = validate_query_parameters(params) # type: ignore
204204

205-
order_by_vars = []
206-
for clause in strict_params.order_by or []:
207-
if isinstance(clause, OrderByMetric):
208-
order_by_vars.append({"metric": {"name": clause.name}, "descending": clause.descending})
209-
else:
210-
assert isinstance(clause, OrderByGroupBy)
211-
order_by_vars.append(
212-
{"groupBy": {"name": clause.name, "timeGranularity": clause.grain}, "descending": clause.descending}
213-
)
214-
215205
shared_vars = {
216206
"environmentId": environment_id,
217207
"where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None,
218-
"orderBy": order_by_vars if strict_params.order_by is not None else None,
208+
"orderBy": [
209+
{"metric": {"name": clause.name}, "descending": clause.descending}
210+
if isinstance(clause, OrderByMetric)
211+
else {"groupBy": {"name": clause.name, "timeGranularity": clause.grain}, "descending": clause.descending}
212+
for clause in strict_params.order_by
213+
]
214+
if strict_params.order_by is not None
215+
else None,
219216
"limit": strict_params.limit,
220217
"readCache": strict_params.read_cache,
221218
}

dbtsl/api/shared/query_params.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class QueryParameters(TypedDict, total=False):
5353
class AdhocQueryParametersStrict:
5454
"""The parameters of an adhoc query, strictly validated."""
5555

56-
metrics: Optional[List[Union[str, OrderByMetric]]]
57-
group_by: Optional[List[Union[str, OrderByGroupBy]]]
56+
metrics: Optional[List[str]]
57+
group_by: Optional[List[Union[GroupByParam, str]]]
5858
limit: Optional[int]
5959
order_by: Optional[List[OrderBySpec]]
6060
where: Optional[List[str]]
@@ -73,8 +73,8 @@ class SavedQueryQueryParametersStrict:
7373

7474

7575
def validate_order_by(
76-
known_metrics: List[Union[str, OrderByMetric]],
77-
known_group_bys: List[Union[str, OrderByGroupBy]],
76+
known_metrics: List[str],
77+
known_group_bys: List[Union[str, GroupByParam]],
7878
clause: Union[OrderBySpec, str],
7979
) -> OrderBySpec:
8080
"""Validate an order by clause like `-metric_name`."""
@@ -85,19 +85,12 @@ def validate_order_by(
8585
if descending or clause.startswith("+"):
8686
clause = clause[1:]
8787

88-
normalized_known_metrics: list[str] = []
89-
for known_metric in known_metrics:
90-
if isinstance(known_metric, OrderByMetric):
91-
normalized_known_metrics.append(known_metric.name)
92-
else:
93-
normalized_known_metrics.append(known_metric)
94-
95-
if clause in normalized_known_metrics:
88+
if clause in known_metrics:
9689
return OrderByMetric(name=clause, descending=descending)
9790

9891
normalized_known_group_bys: list[str] = []
9992
for known_group_by in known_group_bys:
100-
if isinstance(known_group_by, OrderByGroupBy):
93+
if isinstance(known_group_by, GroupByParam):
10194
normalized_known_group_bys.append(known_group_by.name)
10295
else:
10396
normalized_known_group_bys.append(known_group_by)

0 commit comments

Comments
 (0)