@@ -242,10 +242,126 @@ def setoutputsize(self, size, column):
242242 raise prestodb .exceptions .NotSupportedError
243243
244244 def execute (self , operation , params = None ):
245- self ._query = prestodb .client .PrestoQuery (self ._request , sql = operation )
246- result = self ._query .execute ()
247- self ._iterator = iter (result )
248- return result
245+ if params :
246+ assert isinstance (params , (list , tuple )), (
247+ "params must be a list or tuple containing the query "
248+ "parameter values"
249+ )
250+
251+ statement_name = self ._generate_unique_statement_name ()
252+ self ._prepare_statement (operation , statement_name )
253+
254+ try :
255+ # Send execute statement and assign the return value to `results`
256+ # as it will be returned by the function
257+ self ._query = self ._execute_prepared_statement (statement_name , params )
258+ self ._iterator = iter (self ._query .execute ())
259+ finally :
260+ # Send deallocate statement
261+ # At this point the query can be deallocated since it has already
262+ # been executed
263+ # TODO: Consider caching prepared statements if requested by caller
264+ self ._deallocate_prepared_statement (statement_name )
265+ else :
266+ self ._query = prestodb .client .PrestoQuery (self ._request , sql = operation )
267+ self ._iterator = iter (self ._query .execute ())
268+ return self
269+
270+ def _generate_unique_statement_name (self ):
271+ return "st_" + uuid .uuid4 ().hex .replace ("-" , "" )
272+
273+ def _prepare_statement (self , statement : str , name : str ) -> None :
274+ sql = f"PREPARE { name } FROM { statement } "
275+ query = prestodb .client .PrestoQuery (self ._request , sql = sql )
276+ query .execute ()
277+
278+ def _execute_prepared_statement (self , statement_name , params ):
279+ sql = (
280+ "EXECUTE "
281+ + statement_name
282+ + " USING "
283+ + "," .join (map (self ._format_prepared_param , params ))
284+ )
285+ return prestodb .client .PrestoQuery (self ._request , sql = sql )
286+
287+ def _deallocate_prepared_statement (self , statement_name : str ) -> None :
288+ sql = "DEALLOCATE PREPARE " + statement_name
289+ query = prestodb .client .PrestoQuery (self ._request , sql = sql )
290+ query .execute ()
291+
292+ def _format_prepared_param (self , param ):
293+ """
294+ Formats parameters to be passed in an
295+ EXECUTE statement.
296+ """
297+ if param is None :
298+ return "NULL"
299+
300+ if isinstance (param , bool ):
301+ return "true" if param else "false"
302+
303+ if isinstance (param , int ):
304+ # TODO represent numbers exceeding 64-bit (BIGINT) as DECIMAL
305+ return "%d" % param
306+
307+ if isinstance (param , float ):
308+ if param == float ("+inf" ):
309+ return "infinity()"
310+ if param == float ("-inf" ):
311+ return "-infinity()"
312+ return "DOUBLE '%s'" % param
313+
314+ if isinstance (param , str ):
315+ return "'%s'" % param .replace ("'" , "''" )
316+
317+ if isinstance (param , bytes ):
318+ return "X'%s'" % param .hex ()
319+
320+ if isinstance (param , datetime .datetime ) and param .tzinfo is None :
321+ datetime_str = param .strftime ("%Y-%m-%d %H:%M:%S.%f" )
322+ return "TIMESTAMP '%s'" % datetime_str
323+
324+ if isinstance (param , datetime .datetime ) and param .tzinfo is not None :
325+ datetime_str = param .strftime ("%Y-%m-%d %H:%M:%S.%f" )
326+ # offset-based timezones
327+ return "TIMESTAMP '%s %s'" % (datetime_str , param .tzinfo .tzname (param ))
328+
329+ # We can't calculate the offset for a time without a point in time
330+ if isinstance (param , datetime .time ) and param .tzinfo is None :
331+ time_str = param .strftime ("%H:%M:%S.%f" )
332+ return "TIME '%s'" % time_str
333+
334+ if isinstance (param , datetime .time ) and param .tzinfo is not None :
335+ time_str = param .strftime ("%H:%M:%S.%f" )
336+ # offset-based timezones
337+ return "TIME '%s %s'" % (time_str , param .strftime ("%Z" )[3 :])
338+
339+ if isinstance (param , datetime .date ):
340+ date_str = param .strftime ("%Y-%m-%d" )
341+ return "DATE '%s'" % date_str
342+
343+ if isinstance (param , list ):
344+ return "ARRAY[%s]" % "," .join (map (self ._format_prepared_param , param ))
345+
346+ if isinstance (param , tuple ):
347+ return "ROW(%s)" % "," .join (map (self ._format_prepared_param , param ))
348+
349+ if isinstance (param , dict ):
350+ keys = list (param .keys ())
351+ values = [param [key ] for key in keys ]
352+ return "MAP({}, {})" .format (
353+ self ._format_prepared_param (keys ), self ._format_prepared_param (values )
354+ )
355+
356+ if isinstance (param , uuid .UUID ):
357+ return "UUID '%s'" % param
358+
359+ if isinstance (param , (bytes , bytearray )):
360+ return "X'%s'" % binascii .hexlify (param ).decode ("utf-8" )
361+
362+ raise prestodb .exceptions .NotSupportedError (
363+ "Query parameter of type '%s' is not supported." % type (param )
364+ )
249365
250366 def executemany (self , operation , seq_of_params ):
251367 raise prestodb .exceptions .NotSupportedError
0 commit comments