Skip to content

Commit c56ac1a

Browse files
authored
chore: make pyright run in strict mode and add mypy (#78)
* ci: add mypy This commit introduces mypy to our pre-commit hooks and CI. This should increase the number of typing bugs we catch. * refactor: make pyright and mypy strict This commit makes both pyright and mypy operate on the strictest setting. I had to change some code around or add type ignores in cases where they are being too strict or when the stubs for dependencies are not precise enough. I only did this for the `dbtsl` package in this commit. I'll refactor the tests in the next commit. * test: pyright in tests This commit changes some of our tests to work with stricter pyright settings. I made the checks for tests a little less strict than the checks for the real `dbtsl` package, since our tests sometimes inject wrong types into the API to purposefully induce errors, or mock stuff that's really hard to type. * docs: changelog
1 parent f72a118 commit c56ac1a

27 files changed

+149
-108
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
kind: Chore
2+
body: Make `pyright` run in `strict` mode and add `mypy`.
3+
time: 2025-04-03T16:30:34.024926+02:00

.github/workflows/code-quality.yaml

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ jobs:
2929
run: "hatch run dev:ruff format --check"
3030

3131
- name: basedpyright
32-
run: "hatch run dev:basedpyright"
32+
run: "hatch run dev:basedpyright dbtsl tests"
33+
34+
- name: mypy
35+
run: "hatch run dev:python -m mypy dbtsl"
3336

3437
- name: fetch server schema
3538
run: "hatch run dev:fetch-schema"

dbtsl/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# pyright: reportUnusedImport=false
1+
# type: ignore
22
try:
33
from dbtsl.client.sync import SyncSemanticLayerClient
44

55
SemanticLayerClient = SyncSemanticLayerClient
66
except ImportError:
77

8-
def err_factory(*args, **kwargs) -> None: # noqa: D103
8+
def err_factory(*_args: object, **_kwargs: object) -> None: # noqa: D103
99
raise ImportError(
1010
"You are trying to use the default `SemanticLayerClient`, "
1111
"but it looks like the necessary dependencies were not installed. "

dbtsl/api/adbc/client/asyncio.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from typing_extensions import Self, Unpack
77

88
from dbtsl.api.adbc.client.base import BaseADBCClient
9-
from dbtsl.api.adbc.protocol import QueryParameters
10-
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters
9+
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters
1110

1211

1312
class AsyncADBCClient(BaseADBCClient):
@@ -59,7 +58,7 @@ async def query(self, **query_params: Unpack[QueryParameters]) -> pa.Table:
5958
# just creating the cursor object doesn't perform any blocking IO.
6059
with self._conn.cursor() as cur:
6160
try:
62-
await self._loop.run_in_executor(None, cur.execute, query_sql)
61+
await self._loop.run_in_executor(None, cur.execute, query_sql) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType]
6362
except Exception as err:
6463
self._handle_error(err)
6564
table = await self._loop.run_in_executor(None, cur.fetch_arrow_table)
@@ -74,7 +73,8 @@ async def dimension_values(self, **query_params: Unpack[DimensionValuesQueryPara
7473
# just creating the cursor object doesn't perform any blocking IO.
7574
with self._conn.cursor() as cur:
7675
try:
77-
await self._loop.run_in_executor(None, cur.execute, query_sql)
76+
await self._loop.run_in_executor(None, cur.execute, query_sql) # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType]
77+
7878
except Exception as err:
7979
self._handle_error(err)
8080
table = await self._loop.run_in_executor(None, cur.fetch_arrow_table)

dbtsl/api/adbc/client/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from adbc_driver_flightsql import DatabaseOptions
66
from adbc_driver_flightsql.dbapi import Connection
7-
from adbc_driver_flightsql.dbapi import connect as adbc_connect
7+
from adbc_driver_flightsql.dbapi import connect as adbc_connect # pyright: ignore[reportUnknownVariableType]
88
from adbc_driver_manager import AdbcStatusCode, ProgrammingError
99

1010
import dbtsl.env as env

dbtsl/api/adbc/client/sync.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from typing_extensions import Self, Unpack
66

77
from dbtsl.api.adbc.client.base import BaseADBCClient
8-
from dbtsl.api.adbc.protocol import QueryParameters
9-
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters
8+
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters
109

1110

1211
class SyncADBCClient(BaseADBCClient):
@@ -53,7 +52,7 @@ def query(self, **query_params: Unpack[QueryParameters]) -> pa.Table:
5352

5453
with self._conn.cursor() as cur:
5554
try:
56-
cur.execute(query_sql)
55+
cur.execute(query_sql) # pyright: ignore[reportUnknownMemberType]
5756
except Exception as err:
5857
self._handle_error(err)
5958

@@ -67,7 +66,7 @@ def dimension_values(self, **query_params: Unpack[DimensionValuesQueryParameters
6766

6867
with self._conn.cursor() as cur:
6968
try:
70-
cur.execute(query_sql)
69+
cur.execute(query_sql) # pyright: ignore[reportUnknownMemberType]
7170
except Exception as err:
7271
self._handle_error(err)
7372

dbtsl/api/adbc/protocol.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _serialize_val(cls, val: Any) -> str:
2020
return str(val)
2121

2222
if isinstance(val, list):
23-
list_str = ",".join(cls._serialize_val(list_val) for list_val in val)
23+
list_str = ",".join(cls._serialize_val(list_val) for list_val in val) # pyright: ignore[reportUnknownVariableType]
2424
return f"[{list_str}]"
2525

2626
if isinstance(val, OrderByMetric):

dbtsl/api/graphql/client/asyncio.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from dbtsl.api.shared.query_params import QueryParameters
2222
from dbtsl.backoff import ExponentialBackoff
2323
from dbtsl.error import ConnectTimeoutError, ExecuteTimeoutError, QueryFailedError, RetryTimeoutError, TimeoutError
24-
from dbtsl.models.query import QueryId, QueryStatus
24+
from dbtsl.models.query import QueryStatus
2525

2626
# aiohttp only started distinguishing between read and connect timeouts after version 3.10
2727
# If the user is using an older version, we fall back to considering them both the same thing
@@ -30,13 +30,13 @@
3030

3131
AiohttpServerTimeout = ServerTimeoutError
3232
AiohttpConnectionTimeout = ConnectionTimeoutError
33-
NEW_AIOHTTP = True
33+
_new_aiohttp = True
3434
except ImportError:
3535
from asyncio import TimeoutError as AsyncioTimeoutError
3636

3737
AiohttpServerTimeout = AsyncioTimeoutError
3838
AiohttpConnectionTimeout = AsyncioTimeoutError
39-
NEW_AIOHTTP = False
39+
_new_aiohttp = False
4040

4141

4242
class AsyncGraphQLClient(BaseGraphQLClient[AIOHTTPTransport, AsyncClientSession]):
@@ -75,7 +75,7 @@ def _create_transport(self, url: str, headers: Dict[str, str]) -> AIOHTTPTranspo
7575
# The following type ignore is OK since gql annotated `timeout` as an `Optional[int]`,
7676
# but aiohttp allows `float` timeouts
7777
# See: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientTimeout
78-
timeout=self.timeout.execute_timeout, # pyright: ignore[reportArgumentType]
78+
timeout=self.timeout.execute_timeout, # type: ignore
7979
ssl_close_timeout=self.timeout.tls_close_timeout,
8080
)
8181

@@ -95,21 +95,21 @@ async def session(self) -> AsyncIterator[Self]:
9595
yield self
9696
self._gql_session_unsafe = None
9797

98-
async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
98+
async def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables: TVariables) -> TResponse:
9999
"""Run a `ProtocolOperation`."""
100100
raw_query = op.get_request_text()
101-
variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
101+
variables = op.get_request_variables(environment_id=self.environment_id, variables=raw_variables)
102102
gql_query = gql(raw_query)
103103

104104
try:
105-
res = await self._gql_session.execute(gql_query, variable_values=variables)
105+
res = await self._gql_session.execute(gql_query, variable_values=variables) # type: ignore
106106
except AiohttpConnectionTimeout as err:
107-
if NEW_AIOHTTP:
107+
if _new_aiohttp:
108108
raise ConnectTimeoutError(timeout_s=self.timeout.connect_timeout) from err
109109
raise TimeoutError(timeout_s=self.timeout.total_timeout) from err
110110
# I found out by trial and error that aiohttp can raise all these different kinds of errors
111111
# depending on where the timeout happened in the stack (aiohttp, anyio, asyncio)
112-
except (AiohttpServerTimeout, asyncio.TimeoutError, BuiltinTimeoutError) as err:
112+
except (AiohttpServerTimeout, asyncio.TimeoutError, BuiltinTimeoutError) as err: # type: ignore
113113
raise ExecuteTimeoutError(timeout_s=self.timeout.execute_timeout) from err
114114
except Exception as err:
115115
raise self._refine_err(err)
@@ -118,10 +118,9 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVa
118118

119119
async def _poll_until_complete(
120120
self,
121-
query_id: QueryId,
122121
poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
122+
variables: TJobStatusVariables,
123123
backoff: Optional[ExponentialBackoff] = None,
124-
**kwargs,
125124
) -> TJobStatusResult:
126125
"""Poll for a job's results until it is in a completed state (SUCCESSFUL or FAILED)."""
127126
if backoff is None:
@@ -132,8 +131,7 @@ async def _poll_until_complete(
132131

133132
start_s = time.time()
134133
for sleep_ms in backoff.iter_ms():
135-
kwargs["query_id"] = query_id
136-
qr = await self._run(poll_op, **kwargs)
134+
qr = await self._run(op=poll_op, raw_variables=variables)
137135
if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
138136
return qr
139137

@@ -149,7 +147,10 @@ async def _poll_until_complete(
149147
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
150148
"""Query the Semantic Layer."""
151149
query_id = await self.create_query(**params)
152-
first_page_results = await self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
150+
first_page_results = await self._poll_until_complete(
151+
poll_op=self.PROTOCOL.get_query_result,
152+
variables={"query_id": query_id, "page_num": 1},
153+
)
153154
if first_page_results.status != QueryStatus.SUCCESSFUL:
154155
raise QueryFailedError()
155156

@@ -164,5 +165,5 @@ async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
164165
]
165166
all_page_results = [first_page_results] + await asyncio.gather(*tasks)
166167
tables = [r.result_table for r in all_page_results]
167-
final_table = pa.concat_tables(tables)
168+
final_table = pa.concat_tables(tables) # type: ignore
168169
return final_table

dbtsl/api/graphql/client/asyncio.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# mypy: disable-error-code="misc"
2+
13
from contextlib import AbstractAsyncContextManager
24
from typing import List, Optional, Self, Union
35

dbtsl/api/graphql/client/base.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import functools
21
import warnings
32
from abc import abstractmethod
43
from typing import Any, Dict, Generic, Optional, Protocol, TypeVar, Union
@@ -126,13 +125,15 @@ def __getattr__(self, attr: str) -> Any:
126125

127126
with warnings.catch_warnings():
128127
warnings.simplefilter("ignore", DeprecationWarning)
129-
return functools.partial(
130-
self._run,
131-
op=op,
132-
)
128+
129+
def wrapped(**kwargs: Any) -> Any:
130+
return self._run(op=op, raw_variables=kwargs)
131+
132+
return wrapped
133133

134134

135-
TClient = TypeVar("TClient", bound=BaseGraphQLClient, covariant=True)
135+
# TODO: have to type ignore, see: https://github.com/microsoft/pyright/issues/3497
136+
TClient = TypeVar("TClient", bound=BaseGraphQLClient, covariant=True) # type: ignore
136137

137138

138139
class GraphQLClientFactory(Protocol, Generic[TClient]): # noqa: D101

dbtsl/api/graphql/client/sync.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from dbtsl.api.shared.query_params import QueryParameters
2626
from dbtsl.backoff import ExponentialBackoff
2727
from dbtsl.error import ConnectTimeoutError, ExecuteTimeoutError, QueryFailedError, RetryTimeoutError
28-
from dbtsl.models.query import QueryId, QueryStatus
28+
from dbtsl.models.query import QueryStatus
2929

3030

3131
class SyncGraphQLClient(BaseGraphQLClient[RequestsHTTPTransport, SyncClientSession]):
@@ -64,7 +64,7 @@ def _create_transport(self, url: str, headers: Dict[str, str]) -> RequestsHTTPTr
6464
# but requests allows `tuple[float, float]` timeouts
6565
# See: https://github.com/graphql-python/gql/blob/b066e8944b0da0a4bbac6c31f43e5c3c7772cd51/gql/transport/requests.py#L393
6666
# See: https://requests.readthedocs.io/en/latest/user/advanced/#timeouts
67-
timeout=(self.timeout.connect_timeout, self.timeout.execute_timeout), # pyright: ignore[reportArgumentType]
67+
timeout=(self.timeout.connect_timeout, self.timeout.execute_timeout), # type: ignore
6868
)
6969

7070
@contextmanager
@@ -83,14 +83,14 @@ def session(self) -> Iterator[Self]:
8383
yield self
8484
self._gql_session_unsafe = None
8585

86-
def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
86+
def _run(self, op: ProtocolOperation[TVariables, TResponse], raw_variables: TVariables) -> TResponse:
8787
"""Run a `ProtocolOperation`."""
8888
raw_query = op.get_request_text()
89-
variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
89+
variables = op.get_request_variables(environment_id=self.environment_id, variables=raw_variables)
9090
gql_query = gql(raw_query)
9191

9292
try:
93-
res = self._gql_session.execute(gql_query, variable_values=variables)
93+
res = self._gql_session.execute(gql_query, variable_values=variables) # type: ignore
9494
except RequestsReadTimeout as err:
9595
raise ExecuteTimeoutError(timeout_s=self.timeout.execute_timeout) from err
9696
except RequestsConnectTimeout as err:
@@ -102,10 +102,9 @@ def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariable
102102

103103
def _poll_until_complete(
104104
self,
105-
query_id: QueryId,
106105
poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
106+
variables: TJobStatusVariables,
107107
backoff: Optional[ExponentialBackoff] = None,
108-
**kwargs,
109108
) -> TJobStatusResult:
110109
"""Poll for a query's results until it is in a completed state (SUCCESSFUL or FAILED).
111110
@@ -120,9 +119,7 @@ def _poll_until_complete(
120119

121120
start_s = time.time()
122121
for sleep_ms in backoff.iter_ms():
123-
kwargs["query_id"] = query_id
124-
125-
qr = self._run(poll_op, **kwargs)
122+
qr = self._run(op=poll_op, raw_variables=variables)
126123
if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
127124
return qr
128125

@@ -138,7 +135,13 @@ def _poll_until_complete(
138135
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
139136
"""Query the Semantic Layer."""
140137
query_id = self.create_query(**params)
141-
first_page_results = self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
138+
first_page_results = self._poll_until_complete(
139+
poll_op=self.PROTOCOL.get_query_result,
140+
variables={
141+
"query_id": query_id,
142+
"page_num": 1,
143+
},
144+
)
142145
if first_page_results.status != QueryStatus.SUCCESSFUL:
143146
raise QueryFailedError()
144147

@@ -153,5 +156,5 @@ def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
153156
]
154157
all_page_results = [first_page_results] + results
155158
tables = [r.result_table for r in all_page_results]
156-
final_table = pa.concat_tables(tables)
159+
final_table = pa.concat_tables(tables) # type: ignore
157160
return final_table

dbtsl/api/graphql/client/sync.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# mypy: disable-error-code="misc"
2+
13
from contextlib import AbstractContextManager
24
from typing import Iterator, List, Optional, Union
35

0 commit comments

Comments
 (0)