Commit d587118

Fixing bug for boolean type on spark.to_redshift()
Parent: 7f0e704
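The substantive change is a one-line fix in spark2redshift(): PySpark's DataFrame.dtypes reports a BooleanType column as the string "boolean", but the old mapping only matched "bool", so boolean columns were never recognized (presumably falling through to the function's unsupported-type handling, which is not shown in this diff). A minimal standalone sketch of the before/after behavior, simplified from awswrangler/data_types.py -- the ValueError fallback here is an assumption for illustration, not the library's actual error path:

    # Sketch only; simplified from awswrangler/data_types.py.
    def spark2redshift_before(dtype: str) -> str:
        if dtype == "bool":  # never matches: Spark reports "boolean"
            return "BOOLEAN"
        raise ValueError(f"Unsupported type: {dtype}")  # hypothetical fallback

    def spark2redshift_after(dtype: str) -> str:
        if dtype in ("bool", "boolean"):  # accept both spellings
            return "BOOLEAN"
        raise ValueError(f"Unsupported type: {dtype}")  # hypothetical fallback

    print(spark2redshift_after("boolean"))  # -> BOOLEAN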

File tree: 5 files changed (+73, -37 lines)

  awswrangler/data_types.py
  requirements-dev.txt
  requirements.txt
  setup.py
  testing/test_awswrangler/test_redshift.py

awswrangler/data_types.py
Lines changed: 30 additions & 30 deletions

@@ -12,15 +12,15 @@
 
 def athena2pandas(dtype: str) -> str:
     dtype = dtype.lower()
-    if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
+    if dtype in ("int", "integer", "bigint", "smallint", "tinyint"):
         return "Int64"
-    elif dtype in ["float", "double", "real"]:
+    elif dtype in ("float", "double", "real"):
         return "float64"
     elif dtype == "boolean":
         return "bool"
-    elif dtype in ["string", "char", "varchar"]:
+    elif dtype in ("string", "char", "varchar"):
         return "str"
-    elif dtype in ["timestamp", "timestamp with time zone"]:
+    elif dtype in ("timestamp", "timestamp with time zone"):
         return "datetime64"
     elif dtype == "date":
         return "date"

@@ -36,17 +36,17 @@ def athena2pyarrow(dtype: str) -> str:
         return "int8"
     if dtype == "smallint":
         return "int16"
-    elif dtype in ["int", "integer"]:
+    elif dtype in ("int", "integer"):
         return "int32"
     elif dtype == "bigint":
         return "int64"
     elif dtype == "float":
         return "float32"
     elif dtype == "double":
         return "float64"
-    elif dtype in ["boolean", "bool"]:
+    elif dtype in ("boolean", "bool"):
         return "bool"
-    elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+    elif dtype in ("string", "char", "varchar", "array", "row", "map"):
         return "string"
     elif dtype == "timestamp":
         return "timestamp[ns]"

@@ -58,13 +58,13 @@ def athena2pyarrow(dtype: str) -> str:
 
 def athena2python(dtype: str) -> Optional[type]:
     dtype = dtype.lower()
-    if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
+    if dtype in ("int", "integer", "bigint", "smallint", "tinyint"):
         return int
-    elif dtype in ["float", "double", "real"]:
+    elif dtype in ("float", "double", "real"):
         return float
     elif dtype == "boolean":
         return bool
-    elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+    elif dtype in ("string", "char", "varchar", "array", "row", "map"):
         return str
     elif dtype == "timestamp":
         return datetime

@@ -80,17 +80,17 @@ def athena2redshift(dtype: str) -> str:
     dtype = dtype.lower()
     if dtype == "smallint":
         return "SMALLINT"
-    elif dtype in ["int", "integer"]:
+    elif dtype in ("int", "integer"):
         return "INTEGER"
     elif dtype == "bigint":
         return "BIGINT"
     elif dtype == "float":
         return "FLOAT4"
     elif dtype == "double":
         return "FLOAT8"
-    elif dtype in ["boolean", "bool"]:
+    elif dtype in ("boolean", "bool"):
         return "BOOL"
-    elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+    elif dtype in ("string", "char", "varchar", "array", "row", "map"):
         return "VARCHAR(256)"
     elif dtype == "timestamp":
         return "TIMESTAMP"

@@ -104,7 +104,7 @@ def pandas2athena(dtype: str) -> str:
     dtype = dtype.lower()
     if dtype == "int32":
         return "int"
-    elif dtype in ["int64", "Int64"]:
+    elif dtype in ("int64", "Int64"):
         return "bigint"
     elif dtype == "float32":
         return "float"

@@ -214,19 +214,19 @@ def python2athena(python_type: type) -> str:
 
 def redshift2athena(dtype: str) -> str:
     dtype_str = str(dtype)
-    if dtype_str in ["SMALLINT", "INT2"]:
+    if dtype_str in ("SMALLINT", "INT2"):
         return "smallint"
-    elif dtype_str in ["INTEGER", "INT", "INT4"]:
+    elif dtype_str in ("INTEGER", "INT", "INT4"):
         return "int"
-    elif dtype_str in ["BIGINT", "INT8"]:
+    elif dtype_str in ("BIGINT", "INT8"):
         return "bigint"
-    elif dtype_str in ["REAL", "FLOAT4"]:
+    elif dtype_str in ("REAL", "FLOAT4"):
         return "float"
-    elif dtype_str in ["DOUBLE PRECISION", "FLOAT8", "FLOAT"]:
+    elif dtype_str in ("DOUBLE PRECISION", "FLOAT8", "FLOAT"):
         return "double"
-    elif dtype_str in ["BOOLEAN", "BOOL"]:
+    elif dtype_str in ("BOOLEAN", "BOOL"):
         return "boolean"
-    elif dtype_str in ["VARCHAR", "CHARACTER VARYING", "NVARCHAR", "TEXT"]:
+    elif dtype_str in ("VARCHAR", "CHARACTER VARYING", "NVARCHAR", "TEXT"):
         return "string"
     elif dtype_str == "DATE":
         return "date"

@@ -238,19 +238,19 @@ def redshift2athena(dtype: str) -> str:
 
 def redshift2pyarrow(dtype: str) -> str:
     dtype_str: str = str(dtype)
-    if dtype_str in ["SMALLINT", "INT2"]:
+    if dtype_str in ("SMALLINT", "INT2"):
         return "int16"
-    elif dtype_str in ["INTEGER", "INT", "INT4"]:
+    elif dtype_str in ("INTEGER", "INT", "INT4"):
         return "int32"
-    elif dtype_str in ["BIGINT", "INT8"]:
+    elif dtype_str in ("BIGINT", "INT8"):
         return "int64"
-    elif dtype_str in ["REAL", "FLOAT4"]:
+    elif dtype_str in ("REAL", "FLOAT4"):
         return "float32"
-    elif dtype_str in ["DOUBLE PRECISION", "FLOAT8", "FLOAT"]:
+    elif dtype_str in ("DOUBLE PRECISION", "FLOAT8", "FLOAT"):
         return "float64"
-    elif dtype_str in ["BOOLEAN", "BOOL"]:
+    elif dtype_str in ("BOOLEAN", "BOOL"):
         return "bool"
-    elif dtype_str in ["VARCHAR", "CHARACTER VARYING", "NVARCHAR", "TEXT"]:
+    elif dtype_str in ("VARCHAR", "CHARACTER VARYING", "NVARCHAR", "TEXT"):
         return "string"
     elif dtype_str == "DATE":
         return "date32"

@@ -272,7 +272,7 @@ def spark2redshift(dtype: str) -> str:
         return "FLOAT4"
     elif dtype == "double":
         return "FLOAT8"
-    elif dtype == "bool":
+    elif dtype in ("bool", "boolean"):
         return "BOOLEAN"
     elif dtype == "timestamp":
         return "TIMESTAMP"

@@ -308,7 +308,7 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame, preserve_index:
     """
     cols = []
     cols_dtypes = {}
-    if indexes_position not in ["right", "left"]:
+    if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")
 
     # Handle exception data types (e.g. Int64)
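Aside from the spark2redshift() fix, the rest of this file's churn is a mechanical swap of list literals for tuples in membership tests. The two forms are behaviorally identical for `in` checks; the tuple is the conventional choice because it reads as a fixed set of constants (and CPython can fold a constant tuple directly into the compiled bytecode). A quick illustration:

    dtype = "bigint"
    # Same result either way; the tuple form is the idiomatic constant set.
    assert (dtype in ["int", "bigint"]) == (dtype in ("int", "bigint"))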

requirements-dev.txt
Lines changed: 2 additions & 2 deletions

@@ -2,8 +2,8 @@ yapf~=0.28.0
 mypy~=0.740
 flake8~=3.7.9
 pytest-cov~=2.8.1
-cfn-lint~=0.25.0
-twine~=2.0.0
+cfn-lint~=0.25.2
+twine~=3.0.0
 wheel~=0.33.6
 sphinx~=2.2.1
 pyspark~=2.4.4

requirements.txt
Lines changed: 2 additions & 2 deletions

@@ -1,8 +1,8 @@
 numpy~=1.17.4
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.13.18
-boto3~=1.10.18
+botocore~=1.13.21
+boto3~=1.10.21
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

setup.py
Lines changed: 2 additions & 2 deletions

@@ -24,8 +24,8 @@
         "numpy~=1.17.4",
         "pandas~=0.25.3",
         "pyarrow~=0.15.1",
-        "botocore~=1.13.18",
-        "boto3~=1.10.18",
+        "botocore~=1.13.21",
+        "boto3~=1.10.21",
         "s3fs~=0.4.0",
         "tenacity~=6.0.0",
         "pg8000~=1.13.2",

testing/test_awswrangler/test_redshift.py
Lines changed: 37 additions & 1 deletion

@@ -280,6 +280,43 @@ def test_to_redshift_spark_big(session, bucket, redshift_parameters):
     assert len(list(dataframe.columns)) == len(list(rows[0]))
 
 
+def test_to_redshift_spark_bool(session, bucket, redshift_parameters):
+    dataframe = session.spark_session.createDataFrame(
+        pd.DataFrame({
+            "A": [1, 2, 3],
+            "B": [True, False, True]
+        })
+    )
+    print(dataframe)
+    print(dataframe.dtypes)
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    session.spark.to_redshift(
+        dataframe=dataframe,
+        path=f"s3://{bucket}/redshift-load-bool/",
+        connection=con,
+        schema="public",
+        table="test",
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        min_num_partitions=1,
+    )
+    cursor = con.cursor()
+    cursor.execute("SELECT * from public.test")
+    rows = cursor.fetchall()
+    cursor.close()
+    con.close()
+    assert dataframe.count() == len(rows)
+    assert len(list(dataframe.columns)) == len(list(rows[0]))
+    assert type(rows[0][0]) == int
+    assert type(rows[0][1]) == bool
+
+
 def test_stress_to_redshift_spark_big(session, bucket, redshift_parameters):
     dataframe = session.spark_session.createDataFrame(
         pd.DataFrame({

@@ -288,7 +325,6 @@ def test_stress_to_redshift_spark_big(session, bucket, redshift_parameters):
             "C": list(range(1_000_000))
         }))
     dataframe.cache()
-
     for i in range(10):
         print(f"Run number: {i}")
         con = Redshift.generate_connection(
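The new test is an integration test: it round-trips a Spark DataFrame with a boolean column through S3 into a Redshift table and asserts the column comes back as a Python bool. A sketch of running just this test, assuming the repo's live-AWS fixtures (session, bucket, redshift_parameters) resolve against a provisioned cluster:

    # Hypothetical invocation; requires the repo's live-AWS test fixtures.
    import pytest
    pytest.main(["testing/test_awswrangler/test_redshift.py::test_to_redshift_spark_bool"])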
