
Commit 96b6849

Merge branch 'main' into chore/unpin-arrow
2 parents: 8247751 + f994d07

File tree

3 files changed (+83, -5 lines changed)

awswrangler/mysql.py

Lines changed: 4 additions & 4 deletions
@@ -93,7 +93,7 @@ def connect(
     write_timeout: int | None = None,
     connect_timeout: int = 10,
     cursorclass: type["Cursor"] | None = None,
-) -> "pymysql.connections.Connection":  # type: ignore[type-arg]
+) -> "pymysql.connections.Connection":
     """Return a pymysql connection from a Glue Catalog Connection or Secrets Manager.
 
     https://pymysql.readthedocs.io
@@ -231,7 +231,7 @@ def read_sql_query(
 @_utils.check_optional_dependency(pymysql, "pymysql")
 def read_sql_query(
     sql: str,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     index_col: str | list[str] | None = None,
     params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
     chunksize: int | None = None,
@@ -351,7 +351,7 @@ def read_sql_table(
 @_utils.check_optional_dependency(pymysql, "pymysql")
 def read_sql_table(
     table: str,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     schema: str | None = None,
     index_col: str | list[str] | None = None,
     params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
@@ -439,7 +439,7 @@ def read_sql_table(
 @apply_configs
 def to_sql(
     df: pd.DataFrame,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     table: str,
     schema: str,
     mode: _ToSqlModeLiteral = "append",
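
For orientation, a minimal usage sketch of the signatures touched above; the Glue connection name and table names are placeholders, not taken from this diff:

import awswrangler as wr

# Open a pymysql connection through a Glue Catalog Connection,
# read a query result into a DataFrame, and write it back out.
con = wr.mysql.connect("MY_GLUE_CONNECTION")  # placeholder connection name
try:
    df = wr.mysql.read_sql_query(sql="SELECT * FROM test.my_table", con=con)
    wr.mysql.to_sql(df=df, con=con, table="my_table_copy", schema="test")
finally:
    con.close()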

awswrangler/redshift/_read.py

Lines changed: 54 additions & 1 deletion
@@ -241,6 +241,7 @@ def unload_to_files(
     kms_key_id: str | None = None,
     manifest: bool = False,
     partition_cols: list[str] | None = None,
+    cleanpath: bool = False,
     boto3_session: boto3.Session | None = None,
 ) -> None:
     """Unload Parquet files on s3 from a Redshift query result (Through the UNLOAD command).
@@ -294,6 +295,21 @@ def unload_to_files(
         Unload a manifest file on S3.
     partition_cols
         Specifies the partition keys for the unload operation.
+    cleanpath
+        Use CLEANPATH instead of ALLOWOVERWRITE. When True, uses CLEANPATH to remove existing files
+        located in the Amazon S3 path before unloading files. When False (default), uses ALLOWOVERWRITE
+        to overwrite existing files, including the manifest file. These options are mutually exclusive.
+
+        ALLOWOVERWRITE: By default, UNLOAD fails if it finds files that it would possibly overwrite.
+        If ALLOWOVERWRITE is specified, UNLOAD overwrites existing files, including the manifest file.
+
+        CLEANPATH: Removes existing files located in the Amazon S3 path specified in the TO clause
+        before unloading files to the specified location. If you include the PARTITION BY clause,
+        existing files are removed only from the partition folders to receive new files generated
+        by the UNLOAD operation. You must have the s3:DeleteObject permission on the Amazon S3 bucket.
+        Files removed using CLEANPATH are permanently deleted and can't be recovered.
+
+        For more information, see: https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html
     boto3_session
         The default boto3 session will be used if **boto3_session** is ``None``.
@@ -307,6 +323,15 @@ def unload_to_files(
     ...     con=con,
     ...     iam_role="arn:aws:iam::XXX:role/XXX"
     ... )
+    >>> # Using CLEANPATH instead of ALLOWOVERWRITE
+    >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+    ...     wr.redshift.unload_to_files(
+    ...         sql="SELECT * FROM public.mytable",
+    ...         path="s3://bucket/extracted_parquet_files/",
+    ...         con=con,
+    ...         iam_role="arn:aws:iam::XXX:role/XXX",
+    ...         cleanpath=True
+    ...     )
 
 
     """
@@ -339,11 +364,13 @@ def unload_to_files(
     # Escape quotation marks in SQL
     sql = sql.replace("'", "''")
 
+    overwrite_str: str = "CLEANPATH" if cleanpath else "ALLOWOVERWRITE"
+
     unload_sql = (
         f"UNLOAD ('{sql}')\n"
         f"TO '{path}'\n"
         f"{auth_str}"
-        "ALLOWOVERWRITE\n"
+        f"{overwrite_str}\n"
         f"{parallel_str}\n"
         f"FORMAT {format_str}\n"
         "ENCRYPTED"
@@ -376,6 +403,7 @@ def unload(
     chunked: bool | int = False,
     keep_files: bool = False,
     parallel: bool = True,
+    cleanpath: bool = False,
     use_threads: bool | int = True,
     boto3_session: boto3.Session | None = None,
     s3_additional_kwargs: dict[str, str] | None = None,
@@ -452,6 +480,21 @@ def unload(
         By default, UNLOAD writes data in parallel to multiple files, according to the number of
         slices in the cluster. If parallel is False, UNLOAD writes to one or more data files serially,
         sorted absolutely according to the ORDER BY clause, if one is used.
+    cleanpath
+        Use CLEANPATH instead of ALLOWOVERWRITE. When True, uses CLEANPATH to remove existing files
+        located in the Amazon S3 path before unloading files. When False (default), uses ALLOWOVERWRITE
+        to overwrite existing files, including the manifest file. These options are mutually exclusive.
+
+        ALLOWOVERWRITE: By default, UNLOAD fails if it finds files that it would possibly overwrite.
+        If ALLOWOVERWRITE is specified, UNLOAD overwrites existing files, including the manifest file.
+
+        CLEANPATH: Removes existing files located in the Amazon S3 path specified in the TO clause
+        before unloading files to the specified location. If you include the PARTITION BY clause,
+        existing files are removed only from the partition folders to receive new files generated
+        by the UNLOAD operation. You must have the s3:DeleteObject permission on the Amazon S3 bucket.
+        Files removed using CLEANPATH are permanently deleted and can't be recovered.
+
+        For more information, see: https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html
     dtype_backend
         Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
         nullable dtypes are used for all dtypes that have a nullable implementation when
@@ -489,6 +532,15 @@ def unload(
     ...     con=con,
     ...     iam_role="arn:aws:iam::XXX:role/XXX"
     ... )
+    >>> # Using CLEANPATH instead of ALLOWOVERWRITE
+    >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+    ...     df = wr.redshift.unload(
+    ...         sql="SELECT * FROM public.mytable",
+    ...         path="s3://bucket/extracted_parquet_files/",
+    ...         con=con,
+    ...         iam_role="arn:aws:iam::XXX:role/XXX",
+    ...         cleanpath=True
+    ...     )
 
     """
     path = path if path.endswith("/") else f"{path}/"
@@ -505,6 +557,7 @@ def unload(
         kms_key_id=kms_key_id,
         manifest=False,
         parallel=parallel,
+        cleanpath=cleanpath,
         boto3_session=boto3_session,
     )
     if chunked is False:
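
A practical note on choosing between the two options: ALLOWOVERWRITE only replaces files whose names collide with the newly written ones, so re-unloading a smaller result into the same prefix can leave stale files behind, which a subsequent read would pick up. A sketch under that assumption; the connection name, role ARN, and bucket are placeholders:

import awswrangler as wr

with wr.redshift.connect("MY_GLUE_CONNECTION") as con:  # placeholder Glue connection
    # First export: writes one or more files per cluster slice.
    wr.redshift.unload_to_files(
        sql="SELECT * FROM public.mytable",
        path="s3://bucket/exports/mytable/",
        con=con,
        iam_role="arn:aws:iam::XXX:role/XXX",
    )
    # Second export of a smaller result into the same prefix:
    # cleanpath=True deletes everything under the path first,
    # so no files from the first export survive.
    wr.redshift.unload_to_files(
        sql="SELECT * FROM public.mytable WHERE id > 100",
        path="s3://bucket/exports/mytable/",
        con=con,
        iam_role="arn:aws:iam::XXX:role/XXX",
        cleanpath=True,
    )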

tests/unit/test_redshift.py

Lines changed: 25 additions & 0 deletions
@@ -1428,6 +1428,31 @@ def test_unload_escape_quotation_marks(
     assert len(df2) == 1
 
 
+@pytest.mark.parametrize("cleanpath", [False, True])
+def test_unload_cleanpath(
+    path: str,
+    redshift_table: str,
+    redshift_con: redshift_connector.Connection,
+    databases_parameters: dict[str, Any],
+    cleanpath: bool,
+) -> None:
+    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
+    schema = "public"
+
+    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, mode="overwrite", index=False)
+
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM {schema}.{redshift_table}",
+        con=redshift_con,
+        iam_role=databases_parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        cleanpath=cleanpath,
+    )
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 2
+
+
 @pytest.mark.parametrize(
     "mode,overwrite_method",
     [
