 
 import pyarrow
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, InvalidSerDe, ApiError
 
 logger = logging.getLogger(__name__)
 
@@ -155,12 +155,11 @@ def metadata_to_glue(self,
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
-            self.add_partitions(
-                database=database,
-                table=table,
-                partition_paths=partitions_tuples,
-                file_format=file_format,
-            )
+            self.add_partitions(database=database,
+                                table=table,
+                                partition_paths=partitions_tuples,
+                                file_format=file_format,
+                                extra_args=extra_args)
 
     def delete_table_if_exists(self, database, table):
         try:
@@ -184,7 +183,8 @@ def create_table(self,
                      partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
+            table_input = Glue.parquet_table_definition(
+                table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
                                                     partition_cols_schema,
@@ -196,25 +196,31 @@ def create_table(self,
         self._client_glue.create_table(DatabaseName=database,
                                        TableInput=table_input)
 
-    def add_partitions(self, database, table, partition_paths, file_format):
+    def add_partitions(self, database, table, partition_paths, file_format,
+                       extra_args):
         if not partition_paths:
             return None
         partitions = list()
         for partition in partition_paths:
             if file_format == "parquet":
-                partition_def = Glue.parquet_partition_definition(partition)
+                partition_def = Glue.parquet_partition_definition(
+                    partition=partition)
             elif file_format == "csv":
-                partition_def = Glue.csv_partition_definition(partition)
+                partition_def = Glue.csv_partition_definition(
+                    partition=partition, extra_args=extra_args)
             else:
                 raise UnsupportedFileFormat(file_format)
             partitions.append(partition_def)
         pages_num = int(ceil(len(partitions) / 100.0))
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
-            self._client_glue.batch_create_partition(DatabaseName=database,
-                                                     TableName=table,
-                                                     PartitionInputList=page)
+            res = self._client_glue.batch_create_partition(
+                DatabaseName=database,
+                TableName=table,
+                PartitionInputList=page)
+            if len(res["Errors"]) > 0:
+                raise ApiError(f"{res['Errors'][0]}")
 
     def get_connection_details(self, name):
         return self._client_glue.get_connection(
@@ -223,18 +229,25 @@ def get_connection_details(self, name):
     @staticmethod
     def _extract_pyarrow_schema(dataframe, preserve_index):
         cols = []
+        cols_dtypes = {}
         schema = []
+
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
             if str(dtype) == "Int64":
-                schema.append((name, "int64"))
+                cols_dtypes[name] = "int64"
             else:
                 cols.append(name)
 
-        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
-        schema += [(str(x.name), str(x.type))
-                   for x in pyarrow.Schema.from_pandas(
-                       df=dataframe[cols], preserve_index=preserve_index)]
+        for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
+                                                preserve_index=preserve_index):
+            name = str(field.name)
+            dtype = str(field.type)
+            cols_dtypes[name] = dtype
+            if name not in dataframe.columns:
+                schema.append((name, dtype))
+
+        schema += [(name, cols_dtypes[name]) for name in dataframe.columns]
         logger.debug(f"schema: {schema}")
         return schema
 
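(For reference, a minimal standalone sketch, not part of this commit, of the schema extraction above: nullable pandas "Int64" columns skip pyarrow inference and are recorded as "int64", every other column is inferred through pyarrow, and the final schema keeps the dataframe's column order. The sample dataframe below is made up.)

import pandas as pd
import pyarrow

df = pd.DataFrame({
    "id": pd.Series([1, 2, None], dtype="Int64"),  # nullable integer column
    "name": ["a", "b", "c"],
})

cols = []          # columns left for pyarrow to infer
cols_dtypes = {}   # column name -> resolved type string
for name, dtype in df.dtypes.to_dict().items():
    if str(dtype) == "Int64":
        cols_dtypes[name] = "int64"   # handled explicitly, pyarrow skipped
    else:
        cols.append(name)

for field in pyarrow.Schema.from_pandas(df=df[cols], preserve_index=False):
    cols_dtypes[str(field.name)] = str(field.type)

schema = [(name, cols_dtypes[name]) for name in df.columns]
print(schema)  # [('id', 'int64'), ('name', 'string')]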
@@ -256,7 +269,8 @@ def _build_schema(dataframe, partition_cols, preserve_index):
             else:
                 schema_built.append((name, athena_type))
 
-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built = [(name, partition_cols_types[name])
+                                       for name in partition_cols]
 
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(
@@ -270,17 +284,40 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
-        sep = extra_args["sep"] if "sep" in extra_args else ","
+    def csv_table_definition(table, partition_cols_schema, schema, path,
+                             extra_args):
         if not partition_cols_schema:
             partition_cols_schema = []
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+            refined_par_schema = [(name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, "string") for name, dtype in schema]
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+            dtypes_allowed = ["int", "bigint", "float", "double"]
+            refined_par_schema = [(name, dtype) if dtype in dtypes_allowed else
+                                  (name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, dtype) if dtype in dtypes_allowed else
+                              (name, "string") for name, dtype in schema]
+        else:
+            raise InvalidSerDe(f"{serde} is not in the valid SerDe list.")
         return {
             "Name":
             table,
             "PartitionKeys": [{
                 "Name": x[0],
                 "Type": x[1]
-            } for x in partition_cols_schema],
+            } for x in refined_par_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {
@@ -295,54 +332,61 @@ def csv_table_definition(table, partition_cols_schema, schema, path, extra_args)
                 "Columns": [{
                     "Name": x[0],
                     "Type": x[1]
-                } for x in schema],
+                } for x in refined_schema],
                 "Location": path,
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
                 "Compressed": False,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": sep
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
                 "SortColumns": [],
                 "Parameters": {
                     "classification": "csv",
                     "compressionType": "none",
                     "typeOfData": "file",
-                    "delimiter": ",",
+                    "delimiter": sep,
                     "columnsOrdered": "true",
                     "areColumnsQuoted": "false",
                 },
             },
         }
 
     @staticmethod
-    def csv_partition_definition(partition):
+    def csv_partition_definition(partition, extra_args):
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+        else:
+            raise InvalidSerDe(f"{serde} is not in the valid SerDe list.")
         return {
             "StorageDescriptor": {
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "Location": partition[0],
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": ","
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
             },
             "Values": partition[1],
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema,
-                                 schema, path):
+    def parquet_table_definition(table, partition_cols_schema, schema, path):
         if not partition_cols_schema:
             partition_cols_schema = []
         return {
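(A hypothetical usage sketch, not part of this commit, showing how the new extra_args plumbing selects a SerDe in csv_table_definition. The table name, S3 path, and schema below are invented, and the import assumes the Glue class shown here lives in awswrangler's glue module. OpenCSVSerDe declares every column as string, LazySimpleSerDe keeps int/bigint/float/double, and any other serde value raises InvalidSerDe.)

from awswrangler.glue import Glue  # assumed module path

schema = [("id", "bigint"), ("name", "string"), ("score", "double")]
partition_cols_schema = [("year", "int")]

# OpenCSVSerDe: quoted fields, every column coerced to string
opencsv_input = Glue.csv_table_definition(
    table="my_table",
    partition_cols_schema=partition_cols_schema,
    schema=schema,
    path="s3://my-bucket/my_table/",
    extra_args={"sep": ",", "serde": "OpenCSVSerDe"})

# LazySimpleSerDe: keeps the numeric types listed in dtypes_allowed
lazy_input = Glue.csv_table_definition(
    table="my_table",
    partition_cols_schema=partition_cols_schema,
    schema=schema,
    path="s3://my-bucket/my_table/",
    extra_args={"sep": ";", "serde": "LazySimpleSerDe"})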