Skip to content

Commit eb19a67

Browse files
ygf11 and alamb authored
Rewrite coerce_plan_expr_for_schema to fix union type coercion (apache#4862)
* Rewrite coerce_plan_expr_for_schema
* add integration tests
* Update datafusion/expr/src/expr_rewriter.rs

Co-authored-by: Andrew Lamb <[email protected]>

* fix comment
* fix tests
* fix tests

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 3b86643 commit eb19a67

File tree

5 files changed

+225
-19
lines changed

5 files changed

+225
-19
lines changed

datafusion/core/tests/sql/mod.rs

+19
Original file line number | Diff line number | Diff line change
@@ -692,6 +692,25 @@ fn create_sort_merge_join_datatype_context() -> Result<SessionContext> {
692692
Ok(ctx)
693693
}
694694

695+
fn create_union_context() -> Result<SessionContext> {
696+
let ctx = SessionContext::new();
697+
let t1_schema = Arc::new(Schema::new(vec![
698+
Field::new("id", DataType::Int32, true),
699+
Field::new("name", DataType::UInt8, true),
700+
]));
701+
let t1_data = RecordBatch::new_empty(t1_schema);
702+
ctx.register_batch("t1", t1_data)?;
703+
704+
let t2_schema = Arc::new(Schema::new(vec![
705+
Field::new("id", DataType::UInt8, true),
706+
Field::new("name", DataType::UInt8, true),
707+
]));
708+
let t2_data = RecordBatch::new_empty(t2_schema);
709+
ctx.register_batch("t2", t2_data)?;
710+
711+
Ok(ctx)
712+
}
713+
695714
fn get_tpch_table_schema(table: &str) -> Schema {
696715
match table {
697716
"customer" => Schema::new(vec![

datafusion/core/tests/sql/union.rs

+82
Original file line number | Diff line number | Diff line change
@@ -140,3 +140,85 @@ async fn union_schemas() -> Result<()> {
140140
assert_batches_eq!(expected, &result);
141141
Ok(())
142142
}
143+
144+
#[tokio::test]
145+
async fn union_with_except_input() -> Result<()> {
146+
let ctx = create_union_context()?;
147+
let sql = "(
148+
SELECT name FROM t1
149+
EXCEPT
150+
SELECT name FROM t2
151+
)
152+
UNION ALL
153+
(
154+
SELECT name FROM t2
155+
EXCEPT
156+
SELECT name FROM t1
157+
)";
158+
let msg = format!("Creating logical plan for '{sql}'");
159+
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
160+
let plan = dataframe.into_optimized_plan()?;
161+
let expected = vec![
162+
"Explain [plan_type:Utf8, plan:Utf8]",
163+
" Union [name:UInt8;N]",
164+
" LeftAnti Join: t1.name = t2.name [name:UInt8;N]",
165+
" Distinct: [name:UInt8;N]",
166+
" TableScan: t1 projection=[name] [name:UInt8;N]",
167+
" Projection: t2.name [name:UInt8;N]",
168+
" TableScan: t2 projection=[name] [name:UInt8;N]",
169+
" LeftAnti Join: t2.name = t1.name [name:UInt8;N]",
170+
" Distinct: [name:UInt8;N]",
171+
" TableScan: t2 projection=[name] [name:UInt8;N]",
172+
" Projection: t1.name [name:UInt8;N]",
173+
" TableScan: t1 projection=[name] [name:UInt8;N]",
174+
];
175+
176+
let formatted = plan.display_indent_schema().to_string();
177+
let actual: Vec<&str> = formatted.trim().lines().collect();
178+
assert_eq!(
179+
expected, actual,
180+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
181+
);
182+
Ok(())
183+
}
184+
185+
#[tokio::test]
186+
async fn union_with_type_coercion() -> Result<()> {
187+
let ctx = create_union_context()?;
188+
let sql = "(
189+
SELECT id, name FROM t1
190+
EXCEPT
191+
SELECT id, name FROM t2
192+
)
193+
UNION ALL
194+
(
195+
SELECT id, name FROM t2
196+
EXCEPT
197+
SELECT id, name FROM t1
198+
)";
199+
let msg = format!("Creating logical plan for '{sql}'");
200+
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
201+
let plan = dataframe.into_optimized_plan()?;
202+
let expected = vec![
203+
"Explain [plan_type:Utf8, plan:Utf8]",
204+
" Union [id:Int32;N, name:UInt8;N]",
205+
" LeftAnti Join: t1.id = CAST(t2.id AS Int32), t1.name = t2.name [id:Int32;N, name:UInt8;N]",
206+
" Distinct: [id:Int32;N, name:UInt8;N]",
207+
" TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]",
208+
" Projection: t2.id, t2.name [id:UInt8;N, name:UInt8;N]",
209+
" TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]",
210+
" Projection: CAST(t2.id AS Int32) AS id, t2.name [id:Int32;N, name:UInt8;N]",
211+
" LeftAnti Join: CAST(t2.id AS Int32) = t1.id, t2.name = t1.name [id:UInt8;N, name:UInt8;N]",
212+
" Distinct: [id:UInt8;N, name:UInt8;N]",
213+
" TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]",
214+
" Projection: t1.id, t1.name [id:Int32;N, name:UInt8;N]",
215+
" TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]",
216+
];
217+
let formatted = plan.display_indent_schema().to_string();
218+
let actual: Vec<&str> = formatted.trim().lines().collect();
219+
assert_eq!(
220+
expected, actual,
221+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
222+
);
223+
Ok(())
224+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
##########
19+
## UNION Tests
20+
##########
21+
22+
statement ok
23+
CREATE TABLE t1(
24+
id INT,
25+
name TEXT,
26+
) as VALUES
27+
(1, 'Alex'),
28+
(2, 'Bob'),
29+
(3, 'Alice')
30+
;
31+
32+
statement ok
33+
CREATE TABLE t2(
34+
id TINYINT,
35+
name TEXT,
36+
) as VALUES
37+
(1, 'Alex'),
38+
(2, 'Bob'),
39+
(3, 'John')
40+
;
41+
42+
# union with EXCEPT(JOIN)
43+
query T
44+
(
45+
SELECT name FROM t1
46+
EXCEPT
47+
SELECT name FROM t2
48+
)
49+
UNION ALL
50+
(
51+
SELECT name FROM t2
52+
EXCEPT
53+
SELECT name FROM t1
54+
)
55+
ORDER BY name
56+
----
57+
Alice
58+
John
59+
60+
61+
62+
# union with type coercion
63+
query T
64+
(
65+
SELECT * FROM t1
66+
EXCEPT
67+
SELECT * FROM t2
68+
)
69+
UNION ALL
70+
(
71+
SELECT * FROM t2
72+
EXCEPT
73+
SELECT * FROM t1
74+
)
75+
ORDER BY name
76+
----
77+
3 Alice
78+
3 John

datafusion/expr/src/expr_rewriter.rs

+45-18
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ use crate::expr::{
2222
Like, Sort, TryCast, WindowFunction,
2323
};
2424
use crate::logical_plan::{Aggregate, Projection};
25-
use crate::utils::{from_plan, grouping_set_to_exprlist};
25+
use crate::utils::grouping_set_to_exprlist;
2626
use crate::{Expr, ExprSchemable, LogicalPlan};
2727
use datafusion_common::Result;
2828
use datafusion_common::{Column, DFSchema};
@@ -525,29 +525,56 @@ pub fn coerce_plan_expr_for_schema(
525525
plan: &LogicalPlan,
526526
schema: &DFSchema,
527527
) -> Result<LogicalPlan> {
528-
let new_expr = plan
529-
.expressions()
528+
match plan {
529+
// special case Projection to avoid adding multiple projections
530+
LogicalPlan::Projection(Projection { expr, input, .. }) => {
531+
let new_exprs =
532+
coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?;
533+
let projection = Projection::try_new(new_exprs, input.clone())?;
534+
Ok(LogicalPlan::Projection(projection))
535+
}
536+
_ => {
537+
let exprs: Vec<Expr> = plan
538+
.schema()
539+
.fields()
540+
.iter()
541+
.map(|field| Expr::Column(field.qualified_column()))
542+
.collect();
543+
544+
let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
545+
let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err());
546+
if add_project {
547+
let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?;
548+
Ok(LogicalPlan::Projection(projection))
549+
} else {
550+
Ok(plan.clone())
551+
}
552+
}
553+
}
554+
}
555+
556+
fn coerce_exprs_for_schema(
557+
exprs: Vec<Expr>,
558+
src_schema: &DFSchema,
559+
dst_schema: &DFSchema,
560+
) -> Result<Vec<Expr>> {
561+
exprs
530562
.into_iter()
531563
.enumerate()
532-
.map(|(i, expr)| {
533-
let new_type = schema.field(i).data_type();
534-
if plan.schema().field(i).data_type() != schema.field(i).data_type() {
535-
match (plan, &expr) {
536-
(
537-
LogicalPlan::Projection(Projection { input, .. }),
538-
Expr::Alias(e, alias),
539-
) => Ok(e.clone().cast_to(new_type, input.schema())?.alias(alias)),
540-
_ => expr.cast_to(new_type, plan.schema()),
564+
.map(|(idx, expr)| {
565+
let new_type = dst_schema.field(idx).data_type();
566+
if new_type != &expr.get_type(src_schema)? {
567+
match expr {
568+
Expr::Alias(e, alias) => {
569+
Ok(e.cast_to(new_type, src_schema)?.alias(alias))
570+
}
571+
_ => expr.cast_to(new_type, src_schema),
541572
}
542573
} else {
543-
Ok(expr)
574+
Ok(expr.clone())
544575
}
545576
})
546-
.collect::<Result<Vec<_>>>()?;
547-
548-
let new_inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
549-
550-
from_plan(plan, &new_expr, &new_inputs)
577+
.collect::<Result<_>>()
551578
}
552579

553580
#[cfg(test)]

datafusion/expr/src/logical_plan/builder.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -1016,7 +1016,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalP
10161016
match plan {
10171017
LogicalPlan::Projection(Projection { expr, input, .. }) => {
10181018
Ok(Arc::new(project_with_column_index(
1019-
expr.to_vec(),
1019+
expr,
10201020
input,
10211021
Arc::new(union_schema.clone()),
10221022
)?))

0 commit comments

Comments
 (0)