
Commit aef7001

Merge pull request #139 from awslabs/to_csv_filename
Add filename and header arguments for Pandas.to_csv().
2 parents 378badb + 4f63d7b commit aef7001

2 files changed: +34 -4 lines changed


awswrangler/pandas.py

Lines changed: 18 additions & 4 deletions
@@ -511,11 +511,13 @@ def _apply_dates_to_generator(generator, parse_dates):
     def to_csv(self,
                dataframe: pd.DataFrame,
                path: str,
+               filename: Optional[str] = None,
                sep: Optional[str] = None,
                na_rep: Optional[str] = None,
                columns: Optional[List[str]] = None,
                quoting: Optional[int] = None,
                escapechar: Optional[str] = None,
+               header: Union[bool, List[str]] = False,
                serde: Optional[str] = "OpenCSVSerDe",
                database: Optional[str] = None,
                table: Optional[str] = None,
@@ -533,11 +535,13 @@ def to_csv(self,
 
         :param dataframe: Pandas Dataframe
         :param path: Amazon S3 path (e.g. s3://bucket_name/folder_name/)
+        :param filename: The default behavior writes several files with random names, but if you prefer to pass a filename it will disable the parallelism and write a single file with the desired name. NOT VALID for Glue catalog integration.
         :param sep: Same as pandas.to_csv()
         :param na_rep: Same as pandas.to_csv()
         :param columns: Same as pandas.to_csv()
         :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
+        :param header: Same as pandas.to_csv(). NOT VALID for Glue catalog integration.
         :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
@@ -552,17 +556,23 @@ def to_csv(self,
         :param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
         :return: List of objects written on S3
         """
+        if (filename is not None) or (header is not False):
+            database = None
+            table = None
+            procs_cpu_bound = 1
         if (serde is not None) and (serde not in Pandas.VALID_CSV_SERDES):
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
         if (database is not None) and (serde is None):
             raise InvalidParameters(f"It is not possible write to a Glue Database without a SerDe.")
-        extra_args: Dict[str, Optional[Union[str, int, List[str]]]] = {
+        extra_args: Dict[str, Optional[Union[str, int, bool, List[str]]]] = {
+            "filename": filename,
             "sep": sep,
             "na_rep": na_rep,
             "columns": columns,
             "serde": serde,
             "escapechar": escapechar,
-            "quoting": quoting
+            "quoting": quoting,
+            "header": header
         }
         return self.to_s3(dataframe=dataframe,
                           path=path,
@@ -925,7 +935,11 @@ def _data_to_s3_object_writer(dataframe: pd.DataFrame,
         if file_format == "parquet":
             outfile: str = f"{guid}{compression_extension}.parquet"
         elif file_format == "csv":
-            outfile = f"{guid}{compression_extension}.csv"
+            filename: Optional[str] = extra_args.get("filename")
+            if filename is None:
+                outfile = f"{guid}{compression_extension}.csv"
+            else:
+                outfile = filename
         else:
             raise UnsupportedFileFormat(file_format)
         object_path: str = "/".join([path, outfile])
@@ -975,7 +989,7 @@ def _write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra
             csv_extra_args["quoting"] = csv.QUOTE_NONE
             csv_extra_args["escapechar"] = "\\"
         csv_buffer: bytes = bytes(
-            dataframe.to_csv(None, header=False, index=preserve_index, compression=compression, **csv_extra_args),
+            dataframe.to_csv(None, header=extra_args.get("header"), index=preserve_index, compression=compression, **csv_extra_args),
             "utf-8")
         Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)

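For context, a minimal usage sketch of the new arguments (the bucket name, prefix, and sample data are hypothetical; it assumes awswrangler is imported as wr with the module-level pandas helper used in the test below). Passing filename writes a single object with that exact name instead of several randomly named parts, and header is forwarded to pandas.DataFrame.to_csv(); per the guard added above, either option resets database and table to None and forces procs_cpu_bound to 1.

import pandas as pd
import awswrangler as wr

# Hypothetical sample data and S3 location.
df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})

paths = wr.pandas.to_csv(
    dataframe=df,
    path="s3://my-bucket/reports/",  # hypothetical bucket/prefix
    filename="report.csv",           # single file with this name (disables parallelism)
    header=True,                     # write the column names as the first row
    preserve_index=False,
)
print(paths)  # expected to be ["s3://my-bucket/reports/report.csv"]
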
testing/test_awswrangler/test_pandas.py

Lines changed: 16 additions & 0 deletions
@@ -2552,3 +2552,19 @@ def test_read_parquet_int_na(bucket):
     assert len(df2.index) == 10_001
     assert len(df2.columns) == 1
     assert df2.dtypes["col"] == "Int64"
+
+
+def test_to_csv_header_filename(bucket):
+    path = f"s3://{bucket}/test_to_csv_header_filename/"
+    df = pd.DataFrame({"col1": [1, 2], "col2": ["foo", "boo"]})
+    paths = wr.pandas.to_csv(
+        dataframe=df,
+        path=path,
+        filename="file.csv",
+        header=True,
+        preserve_index=False
+    )
+    assert len(paths) == 1
+    assert paths[0].endswith("/test_to_csv_header_filename/file.csv")
+    df2 = wr.pandas.read_csv(path=paths[0])
+    assert df.equals(df2)
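
One behavioral note implied by the guard added to to_csv(): when filename or header is supplied, database and table are reset to None internally, so no Glue catalog registration happens even if they are passed. A hedged illustration, continuing the sketch above (the database and table names are hypothetical):

# Hypothetical call: the catalog arguments are silently dropped because filename/header are set.
paths = wr.pandas.to_csv(
    dataframe=df,
    path="s3://my-bucket/reports/",
    filename="report.csv",
    header=True,
    database="my_database",  # ignored: reset to None by the filename/header guard
    table="my_table",        # ignored: reset to None by the filename/header guard
    preserve_index=False,
)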
