Skip to content

Commit 39b6f91

Browse files
authored
feat: add compile_sql via GraphQL (#44)
* feat: add `compile_sql` via GraphQL This commit adds a new `compile_sql` method to the SDK, which uses the GraphQL API to generate the compiled SQL given a set of query parameters. * test: add integration test for `compile_sql` Added an integration test for `compile_sql` that ensures we get data back from the API for a valid query which contains a `SELECT` statement. * docs: add example for `compile_sql` This commit adds a usage example for `compile_sql`. * docs: add changelog entry
1 parent 13ebfe5 commit 39b6f91

File tree

9 files changed

+131
-3
lines changed

9 files changed

+131
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Features
2+
body: '`compile_sql` method for getting the compiled SQL of a query'
3+
time: 2024-09-20T18:05:50.976574+02:00

dbtsl/api/graphql/client/asyncio.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,10 @@ class AsyncGraphQLClient:
4444
"""Get a list of all available saved queries."""
4545
...
4646

47-
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ...
47+
async def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
48+
"""Get the compiled SQL that would be sent to the warehouse by a query."""
49+
...
50+
51+
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
52+
"""Query the Semantic Layer."""
53+
...

dbtsl/api/graphql/client/sync.pyi

+7-1
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,10 @@ class SyncGraphQLClient:
4444
"""Get a list of all available saved queries."""
4545
...
4646

47-
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ...
47+
def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
48+
"""Get the compiled SQL that would be sent to the warehouse by a query."""
49+
...
50+
51+
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
52+
"""Query the Semantic Layer."""
53+
...

dbtsl/api/graphql/protocol.py

+48
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,53 @@ def parse_response(self, data: Dict[str, Any]) -> QueryResult:
277277
return decode_to_dataclass(data["query"], QueryResult)
278278

279279

280+
class CompileSqlOperation(ProtocolOperation[QueryParameters, str]):
281+
"""Get the compiled SQL that would be sent to the warehouse by a query."""
282+
283+
@override
284+
def get_request_text(self) -> str:
285+
query = """
286+
mutation compileSql(
287+
$environmentId: BigInt!,
288+
$metrics: [MetricInput!]!,
289+
$groupBy: [GroupByInput!]!,
290+
$where: [WhereInput!]!,
291+
$orderBy: [OrderByInput!]!,
292+
$limit: Int,
293+
$readCache: Boolean,
294+
) {
295+
compileSql(
296+
environmentId: $environmentId,
297+
metrics: $metrics,
298+
groupBy: $groupBy,
299+
where: $where,
300+
orderBy: $orderBy,
301+
limit: $limit,
302+
readCache: $readCache,
303+
) {
304+
sql
305+
}
306+
}
307+
"""
308+
return query
309+
310+
@override
311+
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
312+
return {
313+
"environmentId": environment_id,
314+
"metrics": [{"name": m} for m in kwargs.get("metrics", [])],
315+
"groupBy": [{"name": g} for g in kwargs.get("group_by", [])],
316+
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
317+
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
318+
"limit": kwargs.get("limit", None),
319+
"readCache": kwargs.get("read_cache", True),
320+
}
321+
322+
@override
323+
def parse_response(self, data: Dict[str, Any]) -> str:
324+
return data["compileSql"]["sql"]
325+
326+
280327
class GraphQLProtocol:
281328
"""Holds the GraphQL implementation for each of method in the API.
282329
@@ -291,3 +338,4 @@ class GraphQLProtocol:
291338
saved_queries = ListSavedQueriesOperation()
292339
create_query = CreateQueryOperation()
293340
get_query_result = GetQueryResultOperation()
341+
compile_sql = CompileSqlOperation()

dbtsl/client/asyncio.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class AsyncSemanticLayerClient:
1414
auth_token: str,
1515
host: str,
1616
) -> None: ...
17+
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
18+
"""Get the compiled SQL that would be sent to the warehouse by a query."""
19+
...
20+
1721
async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
1822
"""Query the Semantic Layer for a metric data."""
1923
...

dbtsl/client/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ class BaseSemanticLayerClient(ABC, Generic[TGQLClient, TADBCClient]):
2121
"""
2222

2323
_METHOD_MAP = {
24+
"compile_sql": GRAPHQL,
2425
"dimension_values": ADBC,
25-
"query": ADBC,
2626
"dimensions": GRAPHQL,
2727
"entities": GRAPHQL,
2828
"measures": GRAPHQL,
2929
"metrics": GRAPHQL,
30+
"query": ADBC,
3031
"saved_queries": GRAPHQL,
3132
}
3233

dbtsl/client/sync.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class SyncSemanticLayerClient:
1414
auth_token: str,
1515
host: str,
1616
) -> None: ...
17+
def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
18+
"""Get the compiled SQL that would be sent to the warehouse by a query."""
19+
...
20+
1721
def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
1822
"""Query the Semantic Layer for a metric data."""
1923
...

examples/compile_query_sync.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Compile a query and display the SQL."""
2+
3+
from argparse import ArgumentParser
4+
5+
from dbtsl import SemanticLayerClient
6+
7+
8+
def get_arg_parser() -> ArgumentParser:
9+
p = ArgumentParser()
10+
11+
p.add_argument("metric", help="The metric to fetch")
12+
p.add_argument("group_by", help="A dimension to group by")
13+
p.add_argument("--env-id", required=True, help="The dbt environment ID", type=int)
14+
p.add_argument("--token", required=True, help="The API auth token")
15+
p.add_argument("--host", required=True, help="The API host")
16+
17+
return p
18+
19+
20+
def main() -> None:
21+
arg_parser = get_arg_parser()
22+
args = arg_parser.parse_args()
23+
24+
client = SemanticLayerClient(
25+
environment_id=args.env_id,
26+
auth_token=args.token,
27+
host=args.host,
28+
)
29+
30+
with client.session():
31+
sql = client.compile_sql(
32+
metrics=[args.metric],
33+
group_by=[args.group_by],
34+
limit=15,
35+
)
36+
print(f"Compiled SQL for {args.metric} grouped by {args.group_by}, limit 15:")
37+
print(sql)
38+
39+
40+
if __name__ == "__main__":
41+
main()

tests/integration/test_sl_client.py

+15
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,18 @@ async def test_client_query_works(api: str, client: BothClients) -> None:
9292
)
9393
)
9494
assert len(table) > 0
95+
96+
97+
async def test_client_compile_sql_works(client: BothClients) -> None:
98+
metrics = await maybe_await(client.metrics())
99+
assert len(metrics) > 0
100+
101+
sql = await maybe_await(
102+
client.compile_sql(
103+
metrics=[metrics[0].name],
104+
group_by=[metrics[0].dimensions[0].name],
105+
limit=1,
106+
)
107+
)
108+
assert len(sql) > 0
109+
assert "SELECT" in sql

0 commit comments

Comments
 (0)