Skip to content

Commit 08c2cca

Browse files
mlyublenamdesmet
andcommitted
Make ClientSession thread-safe
cherry-pick of trinodb/trino-python-client@79a4814 Co-authored-by: Michiel De Smet <[email protected]>
1 parent 312f11e commit 08c2cca

File tree

5 files changed

+91
-75
lines changed

5 files changed

+91
-75
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Following example shows a use case where both Kerberos and Oauth authentication
6363
```python
6464
import getpass
6565
import prestodb
66-
from prestodb.client import PrestoRequest, PrestoQuery
66+
from prestodb.client import ClientSession, PrestoRequest, PrestoQuery
6767
from requests_kerberos import DISABLED
6868

6969
kerberos_auth = prestodb.auth.KerberosAuthentication(
@@ -76,7 +76,7 @@ kerberos_auth = prestodb.auth.KerberosAuthentication(
7676
req = PrestoRequest(
7777
host='GCP coordinator url',
7878
port=443,
79-
user=getpass.getuser(),
79+
client_session=ClientSession(user=getpass.getuser()),
8080
service_account_file='Service account json file path',
8181
http_scheme='https',
8282
auth=kerberos_auth

integration_tests/fixtures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import pytest
2525
import requests
26-
from prestodb.client import PrestoQuery, PrestoRequest
26+
from prestodb.client import ClientSession, PrestoQuery, PrestoRequest
2727
from prestodb.constants import DEFAULT_PORT
2828
from prestodb.exceptions import TimeoutError
2929

@@ -110,7 +110,7 @@ def start_presto(image_tag=None, build=True, with_cache=True):
110110

111111

112112
def wait_for_presto_workers(host, port, timeout=30):
113-
request = PrestoRequest(host=host, port=port, user="test_fixture")
113+
request = PrestoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture"))
114114
sql = "SELECT state FROM system.runtime.nodes"
115115
t0 = time.time()
116116
while True:

prestodb/client.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@
3737
import logging
3838
import os
3939
from typing import Any, Dict, List, Optional, Text, Tuple, Union # NOQA for mypy types
40-
import six.moves.urllib_parse as parse
4140

4241
import prestodb.redirect
4342
import requests
43+
import six.moves.urllib_parse as parse
4444
from prestodb import constants, exceptions
4545
from prestodb.transaction import NO_TRANSACTION
4646

@@ -62,10 +62,10 @@
6262
class ClientSession(object):
6363
def __init__(
6464
self,
65-
catalog,
66-
schema,
67-
source,
6865
user,
66+
catalog = None,
67+
schema = None,
68+
source = None,
6969
properties=None,
7070
headers=None,
7171
transaction_id=None,
@@ -80,6 +80,9 @@ def __init__(
8080
self._headers = headers or {}
8181
self.transaction_id = transaction_id
8282

83+
def __repr__(self):
84+
return f"ClientSession({self.catalog}, {self.schema}, {self.source}, {self.user}, {self._properties}, {self._headers}, {self.transaction_id})"
85+
8386
@property
8487
def properties(self):
8588
return self._properties
@@ -199,11 +202,7 @@ def __init__(
199202
self,
200203
host, # type: Text
201204
port, # type: int
202-
user, # type: Text
203-
source=None, # type: Text
204-
catalog=None, # type: Text
205-
schema=None, # type: Text
206-
session_properties=None, # type: Optional[Dict[Text, Any]]
205+
client_session, # type: ClientSession
207206
http_session=None, # type: Any
208207
http_headers=None, # type: Optional[Dict[Text, Text]]
209208
transaction_id=NO_TRANSACTION, # type: Optional[Text]
@@ -216,16 +215,7 @@ def __init__(
216215
service_account_file=None,
217216
):
218217
# type: (...) -> None
219-
self._client_session = ClientSession(
220-
catalog,
221-
schema,
222-
source,
223-
user,
224-
session_properties,
225-
http_headers,
226-
transaction_id,
227-
)
228-
218+
self._client_session = client_session
229219
self._host = host
230220
self._port = port
231221
self._next_uri = None # type: Optional[Text]
@@ -539,18 +529,25 @@ def execute(self):
539529

540530
response = self._request.post(self._sql)
541531
status = self._request.process(response)
532+
if status.next_uri is None:
533+
self._finished = True
542534
self.query_id = status.id
543-
self._stats.update({u"queryId": self.query_id})
535+
self._stats.update({"queryId": self.query_id})
544536
self._stats.update(status.stats)
545537
self._warnings = getattr(status, "warnings", [])
546-
if status.next_uri is None:
547-
self._finished = True
548538
self._result = PrestoResult(self, status.rows)
539+
while (
540+
not self._finished and not self._cancelled
541+
):
542+
self._result._rows += self.fetch()
549543
return self._result
550544

551545
def fetch(self):
552546
# type: () -> List[List[Any]]
553547
"""Continue fetching data for the current query_id"""
548+
if self._request.next_uri is None:
549+
self._finished = True
550+
return []
554551
response = self._request.get(self._request.next_uri)
555552
status = self._request.process(response)
556553
if status.columns:

prestodb/dbapi.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717
Fetch methods returns rows as a list of lists on purpose to let the caller
1818
decide to convert then to a list of tuples.
1919
"""
20-
from __future__ import absolute_import
21-
from __future__ import division
22-
from __future__ import print_function
20+
from __future__ import absolute_import, division, print_function
2321

22+
import binascii
23+
import datetime
2424
import logging
25+
import uuid
2526
from typing import Any, List, Optional # NOQA for mypy types
26-
import datetime
2727

28-
from prestodb import constants
29-
import prestodb.exceptions
3028
import prestodb.client
29+
import prestodb.exceptions
3130
import prestodb.redirect
32-
from prestodb.transaction import Transaction, IsolationLevel, NO_TRANSACTION
3331

32+
from prestodb import constants
33+
from prestodb.transaction import IsolationLevel, NO_TRANSACTION, Transaction
34+
from prestodb.transaction import NO_TRANSACTION
3435

3536
__all__ = ["connect", "Connection", "Cursor"]
3637

@@ -75,6 +76,7 @@ def __init__(
7576
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
7677
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
7778
isolation_level=IsolationLevel.AUTOCOMMIT,
79+
**kwargs,
7880
):
7981
self.host = host
8082
self.port = port
@@ -83,6 +85,15 @@ def __init__(
8385
self.catalog = catalog
8486
self.schema = schema
8587
self.session_properties = session_properties
88+
self._client_session = prestodb.client.ClientSession(
89+
user,
90+
catalog,
91+
schema,
92+
source,
93+
session_properties,
94+
http_headers,
95+
NO_TRANSACTION,
96+
)
8697
# mypy cannot follow module import
8798
self._http_session = prestodb.client.PrestoRequest.http.Session()
8899
self.http_headers = http_headers
@@ -141,11 +152,7 @@ def _create_request(self):
141152
return prestodb.client.PrestoRequest(
142153
self.host,
143154
self.port,
144-
self.user,
145-
self.source,
146-
self.catalog,
147-
self.schema,
148-
self.session_properties,
155+
self._client_session,
149156
self._http_session,
150157
self.http_headers,
151158
NO_TRANSACTION,
@@ -187,6 +194,9 @@ def __init__(self, connection, request):
187194
self._iterator = None
188195
self._query = None
189196

197+
def __iter__(self):
198+
return self._iterator
199+
190200
@property
191201
def connection(self):
192202
return self._connection

tests/test_client.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import time
2121

2222
from requests_kerberos.exceptions import KerberosExchangeError
23-
from prestodb.client import PrestoRequest
23+
from prestodb.client import ClientSession, PrestoRequest
2424
from prestodb.auth import KerberosAuthentication
2525
from prestodb import constants
2626
import prestodb.exceptions
@@ -270,12 +270,14 @@ def test_presto_initial_request(monkeypatch):
270270
req = PrestoRequest(
271271
host="coordinator",
272272
port=8080,
273-
user="test",
274-
source="test",
275-
catalog="test",
276-
schema="test",
277-
http_scheme="http",
278-
session_properties={},
273+
client_session=ClientSession(
274+
user="test",
275+
source="test",
276+
catalog="test",
277+
schema="test",
278+
properties={},
279+
),
280+
http_scheme="http"
279281
)
280282

281283
http_resp = PrestoRequest.http.Response()
@@ -318,16 +320,17 @@ def test_request_headers(monkeypatch):
318320
req = PrestoRequest(
319321
host="coordinator",
320322
port=8080,
321-
user=user,
322-
source=source,
323-
catalog=catalog,
324-
schema=schema,
323+
client_session=ClientSession(
324+
user=user,
325+
source=source,
326+
catalog=catalog,
327+
schema=schema,
328+
properties={"hash_partition_count": 500, "needs_url_encoding": 'foo,bar'},
329+
headers={
330+
accept_encoding_header: accept_encoding_value,
331+
client_info_header: client_info_value,
332+
}),
325333
http_scheme="http",
326-
session_properties={"hash_partition_count": 500, "needs_url_encoding": 'foo,bar'},
327-
http_headers={
328-
accept_encoding_header: accept_encoding_value,
329-
client_info_header: client_info_value,
330-
},
331334
redirect_handler=None,
332335
)
333336

@@ -353,8 +356,9 @@ def test_request_invalid_http_headers():
353356
PrestoRequest(
354357
host="coordinator",
355358
port=8080,
356-
user="test",
357-
http_headers={constants.HEADER_USER: "invalid_header"},
359+
client_session=ClientSession(
360+
user="test",
361+
headers={constants.HEADER_USER: "invalid_header"})
358362
)
359363
assert str(value_error.value).startswith("cannot override reserved HTTP header")
360364

@@ -379,7 +383,8 @@ def long_call(request, uri, headers):
379383
req = PrestoRequest(
380384
host=host,
381385
port=port,
382-
user="test",
386+
client_session = ClientSession(
387+
user="test"),
383388
http_scheme=http_scheme,
384389
max_attempts=1,
385390
request_timeout=request_timeout,
@@ -401,12 +406,13 @@ def test_presto_fetch_request(monkeypatch):
401406
req = PrestoRequest(
402407
host="coordinator",
403408
port=8080,
404-
user="test",
405-
source="test",
406-
catalog="test",
407-
schema="test",
409+
client_session = ClientSession(
410+
user="test",
411+
source="test",
412+
catalog="test",
413+
schema="test",
414+
properties={}),
408415
http_scheme="http",
409-
session_properties={},
410416
)
411417

412418
http_resp = PrestoRequest.http.Response()
@@ -424,12 +430,13 @@ def test_presto_fetch_error(monkeypatch):
424430
req = PrestoRequest(
425431
host="coordinator",
426432
port=8080,
427-
user="test",
428-
source="test",
429-
catalog="test",
430-
schema="test",
433+
client_session = ClientSession(
434+
user="test",
435+
source="test",
436+
catalog="test",
437+
schema="test",
438+
properties={}),
431439
http_scheme="http",
432-
session_properties={},
433440
)
434441

435442
http_resp = PrestoRequest.http.Response()
@@ -465,12 +472,13 @@ def test_presto_connection_error(monkeypatch, error_code, error_type, error_mess
465472
req = PrestoRequest(
466473
host="coordinator",
467474
port=8080,
468-
user="test",
469-
source="test",
470-
catalog="test",
471-
schema="test",
475+
client_session = ClientSession(
476+
user="test",
477+
source="test",
478+
catalog="test",
479+
schema="test",
480+
properties={}),
472481
http_scheme="http",
473-
session_properties={},
474482
)
475483

476484
http_resp = PrestoRequest.http.Response()
@@ -511,7 +519,8 @@ def test_authentication_fail_retry(monkeypatch):
511519
req = PrestoRequest(
512520
host="coordinator",
513521
port=8080,
514-
user="test",
522+
client_session = ClientSession(
523+
user="test"),
515524
http_scheme=constants.HTTPS,
516525
auth=kerberos_auth,
517526
max_attempts=attempts,
@@ -538,7 +547,7 @@ def test_503_error_retry(monkeypatch):
538547

539548
attempts = 3
540549
req = PrestoRequest(
541-
host="coordinator", port=8080, user="test", max_attempts=attempts
550+
host="coordinator", port=8080, client_session=ClientSession(user="test"), max_attempts=attempts
542551
)
543552

544553
req.post("URL")
@@ -576,7 +585,7 @@ def test_gateway_redirect(monkeypatch):
576585
socket, "gethostbyaddr", lambda *args: ("finalhost", ["finalhost"], "1.2.3.4")
577586
)
578587

579-
req = PrestoRequest(host="coordinator", port=8080, user="test")
588+
req = PrestoRequest(host="coordinator", port=8080, client_session=ClientSession(user="test"))
580589
result = req.post("http://host:80/path/")
581590
assert gateway_response.count == 3
582591
assert result.ok

0 commit comments

Comments
 (0)