|
17 | 17 | class QueryExecutor: |
18 | 18 | """A class for executing queries.""" |
19 | 19 |
|
20 | | - def __init__(self, session=None, transport_kwargs={}): |
| 20 | + def __init__(self, session=None, transport_kwargs={}, connect_on_init=True): |
21 | 21 | self.gql_client = None |
| 22 | + self.session = None |
22 | 23 | transport = AIOHTTPTransport( |
23 | 24 | url=API_ENDPOINT, |
24 | 25 | headers={"Authorization": "Bearer " + self.token}, |
25 | 26 | ssl=True, |
26 | 27 | **transport_kwargs, |
27 | 28 | ) |
28 | 29 | self.gql_client = gql.Client( |
29 | | - transport=transport, fetch_schema_from_transport=True |
| 30 | + transport=transport, fetch_schema_from_transport=False |
30 | 31 | ) |
31 | 32 |
|
32 | | - asyncio.run(self.__ainit__(session)) |
| 33 | + if connect_on_init: |
| 34 | + asyncio.run(self.__ainit__(session)) |
33 | 35 |
|
34 | 36 | async def __ainit__(self, session): |
35 | 37 | self.session = session or await self.gql_client.connect_async() |
36 | 38 | asyncio_atexit.register(self.gql_client.close_async) |
| 39 | + |
| 40 | + async def _ensure_connected(self): |
| 41 | + """Ensure the client is connected before executing queries.""" |
| 42 | + if self.session is None: |
| 43 | + self.session = await self.gql_client.connect_async() |
| 44 | + asyncio_atexit.register(self.gql_client.close_async) |
37 | 45 |
|
38 | 46 | def execute_query( |
39 | 47 | self, access_token: str, query: str, max_tries: int = 1, **kwargs |
@@ -89,6 +97,7 @@ async def execute_async( |
89 | 97 | return result |
90 | 98 |
|
91 | 99 | async def execute_async_single(self, access_token: str, query: str): |
| 100 | + await self._ensure_connected() |
92 | 101 | try: |
93 | 102 | result = await self.gql_client.execute_async(gql.gql(query)) |
94 | 103 | except TransportQueryError as e: |
|
0 commit comments