33from contextlib import contextmanager
44from dataclasses import dataclass
55from itertools import chain , repeat
6- from typing import Callable , Dict , Mapping , Optional
6+ from typing import Any , Callable , Dict , Mapping , Optional , Tuple
77
8+ import agate
89import dbt .exceptions
910import pyodbc
1011from azure .core .credentials import AccessToken
1718)
1819from dbt .adapters .base import Credentials
1920from dbt .adapters .sql import SQLConnectionManager
20- from dbt .contracts .connection import AdapterResponse
21+ from dbt .clients .agate_helper import empty_table
22+ from dbt .contracts .connection import AdapterResponse , Connection , ConnectionState
2123from dbt .events import AdapterLogger
2224
2325from dbt .adapters .sqlserver import __version__
@@ -46,6 +48,7 @@ class SQLServerCredentials(Credentials):
4648 authentication : Optional [str ] = "sql"
4749 encrypt : Optional [bool ] = False
4850 trust_cert : Optional [bool ] = False
51+ retries : int = 1
4952
5053 _ALIASES = {
5154 "user" : "UID" ,
@@ -287,7 +290,6 @@ def exception_handler(self, sql):
287290 self .release ()
288291 except pyodbc .Error :
289292 logger .debug ("Failed to release connection!" )
290- pass
291293
292294 raise dbt .exceptions .DatabaseException (str (e ).strip ()) from e
293295
@@ -304,69 +306,73 @@ def exception_handler(self, sql):
304306 raise dbt .exceptions .RuntimeException (e )
305307
306308 @classmethod
307- def open (cls , connection ) :
309+ def open (cls , connection : Connection ) -> Connection :
308310
309- if connection .state == "open" :
311+ if connection .state == ConnectionState . OPEN :
310312 logger .debug ("Connection is already open, skipping open." )
311313 return connection
312314
313- credentials = connection .credentials
315+ credentials = cls . get_credentials ( connection .credentials )
314316
315- try :
316- con_str = []
317- con_str .append (f"DRIVER={{{ credentials .driver } }}" )
318-
319- if "\\ " in credentials .host :
317+ con_str = [f"DRIVER={{{ credentials .driver } }}" ]
320318
321- # If there is a backslash \ in the host name, the host is a
322- # SQL Server named instance. In this case then port number has to be omitted.
323- con_str .append (f"SERVER={ credentials .host } " )
324- else :
325- con_str .append (f"SERVER={ credentials .host } ,{ credentials .port } " )
319+ if "\\ " in credentials .host :
326320
327- con_str .append (f"Database={ credentials .database } " )
321+ # If there is a backslash \ in the host name, the host is a
322+ # SQL Server named instance. In this case then port number has to be omitted.
323+ con_str .append (f"SERVER={ credentials .host } " )
324+ else :
325+ con_str .append (f"SERVER={ credentials .host } ,{ credentials .port } " )
328326
329- type_auth = getattr ( credentials , "authentication" , "sql " )
327+ con_str . append ( f"Database= { credentials . database } " )
330328
331- if "ActiveDirectory" in type_auth :
332- con_str .append (f"Authentication={ credentials .authentication } " )
329+ type_auth = getattr (credentials , "authentication" , "sql" )
333330
334- if type_auth == "ActiveDirectoryPassword" :
335- con_str .append (f"UID={{{ credentials .UID } }}" )
336- con_str .append (f"PWD={{{ credentials .PWD } }}" )
337- elif type_auth == "ActiveDirectoryInteractive" :
338- con_str .append (f"UID={{{ credentials .UID } }}" )
331+ if "ActiveDirectory" in type_auth :
332+ con_str .append (f"Authentication={ credentials .authentication } " )
339333
340- elif getattr (credentials , "windows_login" , False ):
341- con_str .append ("trusted_connection=yes" )
342- elif type_auth == "sql" :
334+ if type_auth == "ActiveDirectoryPassword" :
343335 con_str .append (f"UID={{{ credentials .UID } }}" )
344336 con_str .append (f"PWD={{{ credentials .PWD } }}" )
337+ elif type_auth == "ActiveDirectoryInteractive" :
338+ con_str .append (f"UID={{{ credentials .UID } }}" )
339+
340+ elif getattr (credentials , "windows_login" , False ):
341+ con_str .append ("trusted_connection=yes" )
342+ elif type_auth == "sql" :
343+ con_str .append (f"UID={{{ credentials .UID } }}" )
344+ con_str .append (f"PWD={{{ credentials .PWD } }}" )
345+
346+ # still confused whether to use "Yes", "yes", "True", or "true"
347+ # to learn more visit
348+ # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
349+ if getattr (credentials , "encrypt" , False ) is True :
350+ con_str .append ("Encrypt=Yes" )
351+ if getattr (credentials , "trust_cert" , False ) is True :
352+ con_str .append ("TrustServerCertificate=Yes" )
345353
346- # still confused whether to use "Yes", "yes", "True", or "true"
347- # to learn more visit
348- # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
349- if getattr (credentials , "encrypt" , False ) is True :
350- con_str .append ("Encrypt=Yes" )
351- if getattr (credentials , "trust_cert" , False ) is True :
352- con_str .append ("TrustServerCertificate=Yes" )
354+ plugin_version = __version__ .version
355+ application_name = f"dbt-{ credentials .type } /{ plugin_version } "
356+ con_str .append (f"Application Name={ application_name } " )
353357
354- plugin_version = __version__ .version
355- application_name = f"dbt-{ credentials .type } /{ plugin_version } "
356- con_str .append (f"Application Name={ application_name } " )
358+ con_str_concat = ";" .join (con_str )
357359
358- con_str_concat = ";" .join (con_str )
360+ index = []
361+ for i , elem in enumerate (con_str ):
362+ if "pwd=" in elem .lower ():
363+ index .append (i )
359364
360- index = []
361- for i , elem in enumerate (con_str ):
362- if "pwd=" in elem .lower ():
363- index .append (i )
365+ if len (index ) != 0 :
366+ con_str [index [0 ]] = "PWD=***"
364367
365- if len (index ) != 0 :
366- con_str [index [0 ]] = "PWD=***"
368+ con_str_display = ";" .join (con_str )
367369
368- con_str_display = ";" .join (con_str )
370+ retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions
371+ pyodbc .InternalError , # not used according to docs, but defined in PEP-249
372+ pyodbc .OperationalError ,
373+ ]
369374
375+ def connect ():
370376 logger .debug (f"Using connection string: { con_str_display } " )
371377
372378 attrs_before = get_pyodbc_attrs_before (credentials )
@@ -375,24 +381,19 @@ def open(cls, connection):
375381 attrs_before = attrs_before ,
376382 autocommit = True ,
377383 )
378-
379- connection .state = "open"
380- connection .handle = handle
381384 logger .debug (f"Connected to db: { credentials .database } " )
385+ return handle
386+
387+ return cls .retry_connection (
388+ connection ,
389+ connect = connect ,
390+ logger = logger ,
391+ retry_limit = credentials .retries ,
392+ retryable_exceptions = retryable_exceptions ,
393+ )
382394
383- except pyodbc .Error as e :
384- logger .debug (f"Could not connect to db: { e } " )
385-
386- connection .handle = None
387- connection .state = "fail"
388-
389- raise dbt .exceptions .FailedToConnectException (str (e ))
390-
391- return connection
392-
393- def cancel (self , connection ):
395+ def cancel (self , connection : Connection ):
394396 logger .debug ("Cancel query" )
395- pass
396397
397398 def add_begin_query (self ):
398399 # return self.add_query('BEGIN TRANSACTION', auto_begin=False)
@@ -402,7 +403,13 @@ def add_commit_query(self):
402403 # return self.add_query('COMMIT TRANSACTION', auto_begin=False)
403404 pass
404405
405- def add_query (self , sql , auto_begin = True , bindings = None , abridge_sql_log = False ):
406+ def add_query (
407+ self ,
408+ sql : str ,
409+ auto_begin : bool = True ,
410+ bindings : Optional [Any ] = None ,
411+ abridge_sql_log : bool = False ,
412+ ) -> Tuple [Connection , Any ]:
406413
407414 connection = self .get_thread_connection ()
408415
@@ -435,11 +442,11 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
435442 return connection , cursor
436443
437444 @classmethod
438- def get_credentials (cls , credentials ) :
445+ def get_credentials (cls , credentials : SQLServerCredentials ) -> SQLServerCredentials :
439446 return credentials
440447
441448 @classmethod
442- def get_response (cls , cursor ) -> AdapterResponse :
449+ def get_response (cls , cursor : Any ) -> AdapterResponse :
443450 # message = str(cursor.statusmessage)
444451 message = "OK"
445452 rows = cursor .rowcount
@@ -456,7 +463,9 @@ def get_response(cls, cursor) -> AdapterResponse:
456463 rows_affected = rows ,
457464 )
458465
459- def execute (self , sql , auto_begin = True , fetch = False ):
466+ def execute (
467+ self , sql : str , auto_begin : bool = True , fetch : bool = False
468+ ) -> Tuple [AdapterResponse , agate .Table ]:
460469 _ , cursor = self .add_query (sql , auto_begin )
461470 response = self .get_response (cursor )
462471 if fetch :
@@ -466,7 +475,7 @@ def execute(self, sql, auto_begin=True, fetch=False):
466475 break
467476 table = self .get_result_from_cursor (cursor )
468477 else :
469- table = dbt . clients . agate_helper . empty_table ()
478+ table = empty_table ()
470479 # Step through all result sets so we process all errors
471480 while cursor .nextset ():
472481 pass
0 commit comments