33import logging
44from datetime import datetime , date
55
6+ import pyarrow
7+
68from awswrangler .exceptions import UnsupportedType , UnsupportedFileFormat
79
810logger = logging .getLogger (__name__ )
@@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
4345 dtypes = self .get_table_athena_types (database = database , table = table )
4446 return {k : Glue .type_athena2python (v ) for k , v in dtypes .items ()}
4547
48+ @staticmethod
def type_pyarrow2athena(dtype):
    """Translate a pyarrow data type into the matching Athena type name.

    Args:
        dtype: A pyarrow type, or its string form (e.g. ``"int64"``,
            ``"timestamp[ns]"``). It is stringified and lower-cased
            before matching, so either form is accepted.

    Returns:
        The Athena type name as a string.

    Raises:
        UnsupportedType: If the pyarrow type has no mapping here.
    """
    dtype = str(dtype).lower()

    # Types whose string form is matched exactly.
    exact_types = {
        "int8": "tinyint",
        "int16": "smallint",
        "int32": "int",
        "int64": "bigint",
        "float": "float",
        "double": "double",
        "bool": "boolean",
        "string": "string",
    }
    if dtype in exact_types:
        return exact_types[dtype]

    # Parametrised types (e.g. "timestamp[ns]", "date32[day]") are
    # matched on their prefix only; unit and timezone are ignored.
    if dtype.startswith("timestamp"):
        return "timestamp"
    if dtype.startswith("date"):
        return "date"

    raise UnsupportedType(f"Unsupported Pyarrow type: { dtype } ")
69+
4670 @staticmethod
4771 def type_pandas2athena (dtype ):
4872 dtype = dtype .lower ()
@@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
5882 return "boolean"
5983 elif dtype == "object" :
6084 return "string"
61- elif dtype [: 10 ] == "datetime64" :
85+ elif dtype . startswith ( "datetime64" ) :
6286 return "timestamp"
6387 else :
6488 raise UnsupportedType (f"Unsupported Pandas type: { dtype } " )
@@ -113,8 +137,7 @@ def metadata_to_glue(self,
113137 extra_args = None ):
114138 schema = Glue ._build_schema (dataframe = dataframe ,
115139 partition_cols = partition_cols ,
116- preserve_index = preserve_index ,
117- cast_columns = cast_columns )
140+ preserve_index = preserve_index )
118141 table = table if table else Glue ._parse_table_name (path )
119142 table = table .lower ().replace ("." , "_" )
120143 if mode == "overwrite" :
@@ -198,31 +221,38 @@ def get_connection_details(self, name):
198221 Name = name , HidePassword = False )["Connection" ]
199222
200223 @staticmethod
201- def _build_schema (dataframe ,
202- partition_cols ,
203- preserve_index ,
204- cast_columns = None ):
def _extract_pyarrow_schema(dataframe, preserve_index):
    """Describe a DataFrame as a list of ``(name, pyarrow_type)`` tuples.

    Columns whose pandas dtype is the nullable ``"Int64"`` are reported
    as ``"int64"`` without being passed to pyarrow (presumably because
    pyarrow could not infer that extension dtype -- TODO confirm). Every
    other column is typed by ``pyarrow.Schema.from_pandas``.

    Entries follow the DataFrame's own column order, so the schema lines
    up positionally with the data. Any extra fields pyarrow reports for
    the index come after the columns.

    Args:
        dataframe: The pandas DataFrame to inspect.
        preserve_index: Forwarded to ``pyarrow.Schema.from_pandas``.

    Returns:
        A list of ``(name, type)`` tuples, both pyarrow-derived parts
        given as strings.
    """
    dtypes = dataframe.dtypes.to_dict()

    # Only the non-"Int64" columns are handed to pyarrow for inference.
    arrow_cols = [name for name, dtype in dtypes.items()
                  if str(dtype) != "Int64"]

    # Convert pyarrow.Schema to a list of tuples
    # (e.g. [(name1, type1), (name2, type2)...])
    arrow_fields = [(str(field.name), str(field.type))
                    for field in pyarrow.Schema.from_pandas(
                        df=dataframe[arrow_cols],
                        preserve_index=preserve_index)]

    # Assumes pyarrow lists one field per selected column, in order,
    # before any index fields; whatever is left over is the index.
    column_fields = dict(zip(arrow_cols, arrow_fields))
    index_fields = arrow_fields[len(arrow_cols):]

    # Walk the columns in their original order so "Int64" columns stay
    # in place instead of being hoisted to the front of the schema.
    schema = [column_fields[name] if name in column_fields
              else (name, "int64")
              for name in dtypes]
    schema += index_fields

    logger.debug(f"schema: { schema } ")
    return schema
240+
241+ @staticmethod
def _build_schema(dataframe, partition_cols, preserve_index):
    """Build the list of ``(column, athena_type)`` pairs for a DataFrame.

    The pyarrow schema is taken from ``Glue._extract_pyarrow_schema`` and
    each type is converted with ``Glue.type_pyarrow2athena``. Columns
    named in ``partition_cols`` are left out of the result.

    Args:
        dataframe: The pandas DataFrame to describe.
        partition_cols: Names of partition columns to exclude; a falsy
            value means no partition columns.
        preserve_index: Forwarded to ``Glue._extract_pyarrow_schema``.

    Returns:
        A list of ``(name, athena_type)`` tuples.
    """
    logger.debug(f"dataframe.dtypes:\n { dataframe.dtypes } ")
    excluded = partition_cols if partition_cols else []

    arrow_schema = Glue._extract_pyarrow_schema(
        dataframe=dataframe, preserve_index=preserve_index)

    # Types are converted only for columns that are kept, so a
    # partition column with an unmapped type does not raise here.
    schema_built = [(col_name, Glue.type_pyarrow2athena(arrow_type))
                    for col_name, arrow_type in arrow_schema
                    if col_name not in excluded]

    logger.debug(f"schema_built:\n { schema_built } ")
    return schema_built
228258
0 commit comments