21
21
from dbtsl .api .shared .query_params import QueryParameters
22
22
from dbtsl .backoff import ExponentialBackoff
23
23
from dbtsl .error import ConnectTimeoutError , ExecuteTimeoutError , QueryFailedError , RetryTimeoutError , TimeoutError
24
- from dbtsl .models .query import QueryId , QueryStatus
24
+ from dbtsl .models .query import QueryStatus
25
25
26
26
# aiohttp only started distinguishing between read and connect timeouts after version 3.10
27
27
# If the user is using an older version, we fall back to considering them both the same thing
30
30
31
31
AiohttpServerTimeout = ServerTimeoutError
32
32
AiohttpConnectionTimeout = ConnectionTimeoutError
33
- NEW_AIOHTTP = True
33
+ _new_aiohttp = True
34
34
except ImportError :
35
35
from asyncio import TimeoutError as AsyncioTimeoutError
36
36
37
37
AiohttpServerTimeout = AsyncioTimeoutError
38
38
AiohttpConnectionTimeout = AsyncioTimeoutError
39
- NEW_AIOHTTP = False
39
+ _new_aiohttp = False
40
40
41
41
42
42
class AsyncGraphQLClient (BaseGraphQLClient [AIOHTTPTransport , AsyncClientSession ]):
@@ -75,7 +75,7 @@ def _create_transport(self, url: str, headers: Dict[str, str]) -> AIOHTTPTranspo
75
75
# The following type ignore is OK since gql annotated `timeout` as an `Optional[int]`,
76
76
# but aiohttp allows `float` timeouts
77
77
# See: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientTimeout
78
- timeout = self .timeout .execute_timeout , # pyright : ignore[reportArgumentType]
78
+ timeout = self .timeout .execute_timeout , # type : ignore
79
79
ssl_close_timeout = self .timeout .tls_close_timeout ,
80
80
)
81
81
@@ -95,21 +95,21 @@ async def session(self) -> AsyncIterator[Self]:
95
95
yield self
96
96
self ._gql_session_unsafe = None
97
97
98
- async def _run (self , op : ProtocolOperation [TVariables , TResponse ], ** kwargs : TVariables ) -> TResponse :
98
+ async def _run (self , op : ProtocolOperation [TVariables , TResponse ], raw_variables : TVariables ) -> TResponse :
99
99
"""Run a `ProtocolOperation`."""
100
100
raw_query = op .get_request_text ()
101
- variables = op .get_request_variables (environment_id = self .environment_id , ** kwargs )
101
+ variables = op .get_request_variables (environment_id = self .environment_id , variables = raw_variables )
102
102
gql_query = gql (raw_query )
103
103
104
104
try :
105
- res = await self ._gql_session .execute (gql_query , variable_values = variables )
105
+ res = await self ._gql_session .execute (gql_query , variable_values = variables ) # type: ignore
106
106
except AiohttpConnectionTimeout as err :
107
- if NEW_AIOHTTP :
107
+ if _new_aiohttp :
108
108
raise ConnectTimeoutError (timeout_s = self .timeout .connect_timeout ) from err
109
109
raise TimeoutError (timeout_s = self .timeout .total_timeout ) from err
110
110
# I found out by trial and error that aiohttp can raise all these different kinds of errors
111
111
# depending on where the timeout happened in the stack (aiohttp, anyio, asyncio)
112
- except (AiohttpServerTimeout , asyncio .TimeoutError , BuiltinTimeoutError ) as err :
112
+ except (AiohttpServerTimeout , asyncio .TimeoutError , BuiltinTimeoutError ) as err : # type: ignore
113
113
raise ExecuteTimeoutError (timeout_s = self .timeout .execute_timeout ) from err
114
114
except Exception as err :
115
115
raise self ._refine_err (err )
@@ -118,10 +118,9 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVa
118
118
119
119
async def _poll_until_complete (
120
120
self ,
121
- query_id : QueryId ,
122
121
poll_op : ProtocolOperation [TJobStatusVariables , TJobStatusResult ],
122
+ variables : TJobStatusVariables ,
123
123
backoff : Optional [ExponentialBackoff ] = None ,
124
- ** kwargs ,
125
124
) -> TJobStatusResult :
126
125
"""Poll for a job's results until it is in a completed state (SUCCESSFUL or FAILED)."""
127
126
if backoff is None :
@@ -132,8 +131,7 @@ async def _poll_until_complete(
132
131
133
132
start_s = time .time ()
134
133
for sleep_ms in backoff .iter_ms ():
135
- kwargs ["query_id" ] = query_id
136
- qr = await self ._run (poll_op , ** kwargs )
134
+ qr = await self ._run (op = poll_op , raw_variables = variables )
137
135
if qr .status in (QueryStatus .SUCCESSFUL , QueryStatus .FAILED ):
138
136
return qr
139
137
@@ -149,7 +147,10 @@ async def _poll_until_complete(
149
147
async def query (self , ** params : Unpack [QueryParameters ]) -> "pa.Table" :
150
148
"""Query the Semantic Layer."""
151
149
query_id = await self .create_query (** params )
152
- first_page_results = await self ._poll_until_complete (query_id , self .PROTOCOL .get_query_result , page_num = 1 )
150
+ first_page_results = await self ._poll_until_complete (
151
+ poll_op = self .PROTOCOL .get_query_result ,
152
+ variables = {"query_id" : query_id , "page_num" : 1 },
153
+ )
153
154
if first_page_results .status != QueryStatus .SUCCESSFUL :
154
155
raise QueryFailedError ()
155
156
@@ -164,5 +165,5 @@ async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
164
165
]
165
166
all_page_results = [first_page_results ] + await asyncio .gather (* tasks )
166
167
tables = [r .result_table for r in all_page_results ]
167
- final_table = pa .concat_tables (tables )
168
+ final_table = pa .concat_tables (tables ) # type: ignore
168
169
return final_table
0 commit comments