@@ -50,6 +50,7 @@ def read_csv(
             max_result_size=None,
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -71,6 +72,7 @@ def read_csv(
         :param max_result_size: Max number of bytes on each request to S3
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -96,6 +98,7 @@ def read_csv(
                 max_result_size=max_result_size,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -113,6 +116,7 @@ def read_csv(
                 key_path=key_path,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -133,6 +137,7 @@ def _read_csv_iterator(
             max_result_size=200_000_000,  # 200 MB
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -155,6 +160,7 @@ def _read_csv_iterator(
         :param max_result_size: Max number of bytes on each request to S3
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -182,6 +188,7 @@ def _read_csv_iterator(
                 key_path=key_path,
                 header=header,
                 names=names,
+                usecols=usecols,
                 dtype=dtype,
                 sep=sep,
                 lineterminator=lineterminator,
@@ -235,6 +242,7 @@ def _read_csv_iterator(
                     StringIO(body[:last_char].decode("utf-8")),
                     header=header,
                     names=names,
+                    usecols=usecols,
                     sep=sep,
                     quotechar=quotechar,
                     quoting=quoting,
@@ -353,6 +361,7 @@ def _read_csv_once(
             key_path,
             header="infer",
             names=None,
+            usecols=None,
             dtype=None,
             sep=",",
             lineterminator="\n",
@@ -374,6 +383,7 @@ def _read_csv_once(
         :param key_path: S3 key path (W/o bucket)
         :param header: Same as pandas.read_csv()
         :param names: Same as pandas.read_csv()
+        :param usecols: Same as pandas.read_csv()
         :param dtype: Same as pandas.read_csv()
         :param sep: Same as pandas.read_csv()
         :param lineterminator: Same as pandas.read_csv()
@@ -395,6 +405,7 @@ def _read_csv_once(
             buff,
             header=header,
             names=names,
+            usecols=usecols,
             sep=sep,
             quotechar=quotechar,
             quoting=quoting,
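The hunks above simply thread the new `usecols` argument from the public `read_csv` entry point through `_read_csv_iterator` and `_read_csv_once` down to the underlying `pandas.read_csv` calls, unchanged. A minimal usage sketch, assuming the session-based API of this library generation and a hypothetical bucket/key:

    import awswrangler

    session = awswrangler.Session()

    # usecols is forwarded as-is to pandas.read_csv, so it accepts the same
    # values (list of column names, list of positions, or a callable).
    df = session.pandas.read_csv(
        path="s3://my-bucket/my-prefix/my-file.csv",  # hypothetical path
        usecols=["id", "value"],
        sep=",",
    )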
@@ -714,7 +725,8 @@ def _data_to_s3_dataset_writer(dataframe,
                                    session_primitives,
                                    file_format,
                                    cast_columns=None,
-                                   extra_args=None):
+                                   extra_args=None,
+                                   isolated_dataframe=False):
         objects_paths = []
         if not partition_cols:
             object_path = Pandas._data_to_s3_object_writer(
@@ -725,7 +737,8 @@ def _data_to_s3_dataset_writer(dataframe,
                 session_primitives=session_primitives,
                 file_format=file_format,
                 cast_columns=cast_columns,
-                extra_args=extra_args)
+                extra_args=extra_args,
+                isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
             for keys, subgroup in dataframe.groupby(partition_cols):
@@ -744,7 +757,8 @@ def _data_to_s3_dataset_writer(dataframe,
                     session_primitives=session_primitives,
                     file_format=file_format,
                     cast_columns=cast_columns,
-                    extra_args=extra_args)
+                    extra_args=extra_args,
+                    isolated_dataframe=True)
                 objects_paths.append(object_path)
         return objects_paths

@@ -769,7 +783,8 @@ def _data_to_s3_dataset_writer_remote(send_pipe,
                 session_primitives=session_primitives,
                 file_format=file_format,
                 cast_columns=cast_columns,
-                extra_args=extra_args))
+                extra_args=extra_args,
+                isolated_dataframe=True))
         send_pipe.close()

     @staticmethod
@@ -780,7 +795,8 @@ def _data_to_s3_object_writer(dataframe,
                                   session_primitives,
                                   file_format,
                                   cast_columns=None,
-                                  extra_args=None):
+                                  extra_args=None,
+                                  isolated_dataframe=False):
         fs = s3.get_fs(session_primitives=session_primitives)
         fs = pyarrow.filesystem._ensure_filesystem(fs)
         s3.mkdir_if_not_exists(fs, path)
@@ -803,12 +819,14 @@ def _data_to_s3_object_writer(dataframe,
             raise UnsupportedFileFormat(file_format)
         object_path = "/".join([path, outfile])
         if file_format == "parquet":
-            Pandas.write_parquet_dataframe(dataframe=dataframe,
-                                           path=object_path,
-                                           preserve_index=preserve_index,
-                                           compression=compression,
-                                           fs=fs,
-                                           cast_columns=cast_columns)
+            Pandas.write_parquet_dataframe(
+                dataframe=dataframe,
+                path=object_path,
+                preserve_index=preserve_index,
+                compression=compression,
+                fs=fs,
+                cast_columns=cast_columns,
+                isolated_dataframe=isolated_dataframe)
         elif file_format == "csv":
             Pandas.write_csv_dataframe(dataframe=dataframe,
                                        path=object_path,
@@ -848,15 +866,17 @@ def write_csv_dataframe(dataframe,

     @staticmethod
     def write_parquet_dataframe(dataframe, path, preserve_index, compression,
-                                fs, cast_columns):
+                                fs, cast_columns, isolated_dataframe):
         if not cast_columns:
             cast_columns = {}

         # Casting on Pandas
+        casted_in_pandas = []
         dtypes = copy.deepcopy(dataframe.dtypes.to_dict())
         for name, dtype in dtypes.items():
             if str(dtype) == "Int64":
                 dataframe[name] = dataframe[name].astype("float64")
+                casted_in_pandas.append(name)
                 cast_columns[name] = "bigint"
                 logger.debug(f"Casting column {name} Int64 to float64")

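For context on the bookkeeping added here: pandas' nullable "Int64" extension dtype was not handled cleanly by the pyarrow versions targeted at the time, so the writer already downcast such columns to float64 before building the Arrow table, while registering them as "bigint" in `cast_columns` so the catalog type stays integral. The new `casted_in_pandas` list simply remembers which columns were downcast. A rough standalone sketch of that idea with plain pandas/pyarrow (column name and values are illustrative):

    import pandas as pd
    import pyarrow as pa

    df = pd.DataFrame({"id": pd.array([1, 2, None], dtype="Int64")})

    # Downcast nullable integer columns to float64 (None becomes NaN) so the
    # Arrow table can be built, and remember which columns were touched.
    casted_in_pandas = []
    for name, dtype in df.dtypes.to_dict().items():
        if str(dtype) == "Int64":
            df[name] = df[name].astype("float64")
            casted_in_pandas.append(name)

    table = pa.Table.from_pandas(df, preserve_index=False)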
@@ -885,6 +905,11 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression,
                             coerce_timestamps="ms",
                             flavor="spark")

+        # Casting back on Pandas if necessary
+        if isolated_dataframe is False:
+            for col in casted_in_pandas:
+                dataframe[col] = dataframe[col].astype("Int64")
+
     def to_redshift(
             self,
             dataframe,
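That restore step is the reason for the new `isolated_dataframe` flag: when the writer is handed the caller's original DataFrame, the temporary float64 downcast would otherwise stay visible after the write returns, so the columns recorded in `casted_in_pandas` are cast back to "Int64". When the frame is already an isolated copy (a `groupby` partition subgroup or the frame shipped to a worker process), the callers above pass `isolated_dataframe=True` and the restore is skipped as wasted work. Continuing the sketch from the previous example, the restore amounts to:

    # True for throwaway partition/worker copies, False when df belongs to the caller.
    isolated_dataframe = False

    # Undo the temporary downcast once the Parquet object has been written,
    # so the caller's DataFrame keeps its nullable Int64 columns.
    if not isolated_dataframe:
        for col in casted_in_pandas:
            df[col] = df[col].astype("Int64")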