66import copy
77import csv
88from datetime import datetime
9+ import ast
910
1011import pandas as pd # type: ignore
1112import pyarrow as pa # type: ignore
@@ -416,6 +417,33 @@ def _read_csv_once(
416417 buff .close ()
417418 return dataframe
418419
def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any]]:
    """
    Build the Pandas ``read_csv`` typing arguments for an Athena result set.

    Fetches the column metadata of a finished Athena query and maps each
    Athena column type to the corresponding Pandas handling strategy.

    :param query_execution_id: Athena query execution ID whose result schema is inspected.
    :return: Tuple of (dtype mapping for plain columns,
             columns to parse as timestamps,
             columns to further truncate to dates,
             per-column converters for complex types).
    """
    cols_metadata: Dict[str, str] = self._session.athena.get_query_columns_metadata(
        query_execution_id=query_execution_id)
    # Lazy %-style args: avoid formatting cost when DEBUG is disabled.
    logger.debug("cols_metadata: %s", cols_metadata)
    dtype: Dict[str, str] = {}
    parse_timestamps: List[str] = []
    parse_dates: List[str] = []
    converters: Dict[str, Any] = {}
    for col_name, col_type in cols_metadata.items():
        pandas_type: str = data_types.athena2pandas(dtype=col_type)
        if pandas_type in ("datetime64", "date"):
            # Both kinds are parsed as timestamps; "date" columns are
            # additionally recorded so the caller can truncate them to dates.
            parse_timestamps.append(col_name)
            if pandas_type == "date":
                parse_dates.append(col_name)
        elif pandas_type == "literal_eval":
            # Complex Athena types arrive in the CSV as Python-literal
            # strings; parse them safely (never eval) per column.
            converters[col_name] = ast.literal_eval
        elif pandas_type == "bool":
            # NOTE(review): bool columns are deliberately left out of dtype —
            # presumably because Athena CSV booleans can be empty/nullable and
            # would break a hard bool cast; confirm against read_csv caller.
            logger.debug("Ignoring bool column: %s", col_name)
        else:
            dtype[col_name] = pandas_type
    logger.debug("dtype: %s", dtype)
    logger.debug("parse_timestamps: %s", parse_timestamps)
    logger.debug("parse_dates: %s", parse_dates)
    return dtype, parse_timestamps, parse_dates, converters
419447 def read_sql_athena (self , sql , database , s3_output = None , max_result_size = None ):
420448 """
421449 Executes any SQL query on AWS Athena and return a Dataframe of the result.
@@ -436,7 +464,7 @@ def read_sql_athena(self, sql, database, s3_output=None, max_result_size=None):
436464 message_error = f"Query error: { reason } "
437465 raise AthenaQueryError (message_error )
438466 else :
439- dtype , parse_timestamps , parse_dates , converters = self ._session . athena . get_query_dtype (
467+ dtype , parse_timestamps , parse_dates , converters = self ._get_query_dtype (
440468 query_execution_id = query_execution_id )
441469 path = f"{ s3_output } { query_execution_id } .csv"
442470 ret = self .read_csv (path = path ,
0 commit comments