Skip to content

Commit 8771e3f

Browse files
committed
[CALCITE-6555] RelBuilder.aggregateRex wrongly thinks aggregate functions of "GROUP BY ()" queries are NOT NULL
In RelBuilder, the aggregateRex method (added in CALCITE-5802) wrongly thinks that aggregate functions in a `GROUP BY ()` query are NOT NULL. Consider the query `SELECT SUM(empno) AS s, COUNT(empno) AS c FROM emp GROUP BY ()`. Here `SUM(empno)` should be nullable, even though `empno` has type `SMALLINT NOT NULL`, because `GROUP BY ()` will return one row even if `emp` has no rows, and therefore `SUM` will be evaluated over the empty set. A RelBuilder test that attempts to build an equivalent query gets the following error stack: `java.lang.AssertionError: type mismatch: ref: SMALLINT NOT NULL input: SMALLINT`. We add a test case for measure queries, because measures are the only code path that uses `aggregateRex` at present.
1 parent 30304bb commit 8771e3f

File tree

3 files changed

+119
-30
lines changed

3 files changed

+119
-30
lines changed

core/src/main/java/org/apache/calcite/tools/RelBuilder.java

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150
import java.util.Set;
151151
import java.util.SortedSet;
152152
import java.util.TreeSet;
153+
import java.util.concurrent.atomic.AtomicInteger;
153154
import java.util.function.BiFunction;
154155
import java.util.function.Consumer;
155156
import java.util.function.Function;
@@ -2632,14 +2633,22 @@ public RelBuilder aggregateRex(GroupKey groupKey, boolean projectKey,
26322633
Iterable<? extends RexNode> nodes) {
26332634
final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
26342635
final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
2635-
for (RexNode node : nodes) {
2636-
aggBuilder.add(node);
2636+
2637+
// First pass. Call convert on each expression to ensure that aggCalls
2638+
// gets populated.
2639+
aggBuilder.registerExpressions(nodes);
2640+
2641+
// Create the Aggregate on the stack.
2642+
aggregate(groupKey, aggBuilder.aggCalls);
2643+
2644+
// Second pass. Call convert on each expression so that it references the
2645+
// actual aggCalls in the Aggregate that was just pushed onto the stack.
2646+
final List<RexNode> projects = new ArrayList<>();
2647+
if (projectKey) {
2648+
projects.addAll(fields(Util.range(groupKey.groupKeyCount())));
26372649
}
2638-
return aggregate(groupKey, aggBuilder.aggCalls)
2639-
.project(
2640-
Iterables.concat(
2641-
fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)),
2642-
aggBuilder.postProjects));
2650+
aggBuilder.convertExpressions(projects::add, nodes);
2651+
return project(projects);
26432652
}
26442653

26452654
/** Finishes the implementation of {@link #aggregate} by creating an
@@ -5040,46 +5049,82 @@ default boolean removeRedundantDistinct() {
50405049
/** Working state for {@link #aggregateRex}. */
50415050
private class AggBuilder {
50425051
final ImmutableList<RexNode> groupKeys;
5043-
final List<RexNode> postProjects = new ArrayList<>();
50445052
final List<AggCall> aggCalls = new ArrayList<>();
50455053

50465054
private AggBuilder(ImmutableList<RexNode> groupKeys) {
50475055
this.groupKeys = groupKeys;
50485056
}
50495057

5050-
/** Adds a node that may or may not contain an aggregate function. */
5051-
void add(RexNode node) {
5052-
postProjects.add(convert(node));
5053-
}
5054-
50555058
/** Adds a node that we know to contain an aggregate function, and returns
50565059
* an expression whose input row type is the output row type of the
50575060
* aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
5058-
private RexNode convert(RexNode node) {
5059-
final RexBuilder rexBuilder = cluster.getRexBuilder();
5060-
if (node instanceof RexCall) {
5061-
final RexCall call = (RexCall) node;
5062-
if (call.getOperator().isAggregator()) {
5063-
final AggCall aggCall =
5064-
aggregateCall((SqlAggFunction) call.op, call.operands);
5065-
final int i = groupKeys.size() + aggCalls.size();
5066-
aggCalls.add(aggCall);
5067-
return rexBuilder.makeInputRef(call.getType(), i);
5061+
private RexNode convert(RegisterAgg registrar, RexNode node,
5062+
@Nullable String name) {
5063+
switch (node.getKind()) {
5064+
case AS:
5065+
final ImmutableList<RexNode> asOperands = ((RexCall) node).operands;
5066+
final String name2;
5067+
if (name != null) {
5068+
name2 = name;
50685069
} else {
5069-
final List<RexNode> operands = new ArrayList<>();
5070-
call.operands.forEach(operand ->
5071-
operands.add(convert(operand)));
5072-
return call.clone(call.type, operands);
5070+
final RexLiteral literal = (RexLiteral) asOperands.get(1);
5071+
name2 = requireNonNull(literal.getValueAs(String.class));
50735072
}
5074-
} else if (node instanceof RexInputRef) {
5073+
final RexNode node2 = convert(registrar, asOperands.get(0), name2);
5074+
return alias(node2, name2);
5075+
5076+
case INPUT_REF:
50755077
final int j = groupKeys.indexOf(node);
50765078
if (j < 0) {
50775079
throw new IllegalArgumentException("not a group key: " + node);
50785080
}
5079-
return rexBuilder.makeInputRef(node.getType(), j);
5080-
} else {
5081+
return field(j);
5082+
5083+
default:
5084+
if (node instanceof RexCall) {
5085+
final RexCall call = (RexCall) node;
5086+
if (call.getOperator().isAggregator()) {
5087+
// return a reference to the i'th agg call
5088+
return registrar.registerAgg((SqlAggFunction) call.op,
5089+
call.operands, call.type, name);
5090+
} else {
5091+
return call.clone(call.type,
5092+
Util.transform(call.operands, operand ->
5093+
convert(registrar, operand, null)));
5094+
}
5095+
}
50815096
return node;
50825097
}
50835098
}
5099+
5100+
void registerExpressions(Iterable<? extends RexNode> nodes) {
5101+
for (RexNode node : nodes) {
5102+
convert(this::registerAgg, node, null);
5103+
}
5104+
}
5105+
5106+
RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
5107+
RelDataType type, @Nullable String name) {
5108+
final int i = groupKeys.size() + aggCalls.size();
5109+
aggCalls.add(aggregateCall(op, operands).as(name));
5110+
return getRexBuilder().makeInputRef(type, i);
5111+
}
5112+
5113+
void convertExpressions(Consumer<RexNode> projects,
5114+
Iterable<? extends RexNode> nodes) {
5115+
final AtomicInteger j = new AtomicInteger(groupKeys.size());
5116+
for (RexNode node : nodes) {
5117+
projects.accept(
5118+
convert((op, operands, type, name) -> field(j.getAndIncrement()),
5119+
node, null));
5120+
}
5121+
}
5122+
}
5123+
5124+
/** Callback to handle creation of an aggregate call in
5125+
* {@link AggBuilder#convert}. */
5126+
private interface RegisterAgg {
5127+
RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands,
5128+
RelDataType type, @Nullable String name);
50845129
}
50855130
}

core/src/test/java/org/apache/calcite/test/RelBuilderTest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3398,6 +3398,29 @@ private static RelBuilder assertSize(RelBuilder b,
33983398
assertThat(r3.getRowType().getFullTypeString(), is(expectedRowType));
33993399
}
34003400

3401+
/** Tests {@link RelBuilder#aggregateRex} with an aggregate call that needs to
3402+
* become nullable because of "GROUP BY ()". */
3403+
@Test void testAggregateRex4() {
3404+
// SELECT SUM(empno) AS s, COUNT(sal) AS c
3405+
// FROM emp
3406+
// GROUP BY ()
3407+
Function<RelBuilder, RelNode> f = b ->
3408+
b.scan("EMP")
3409+
.aggregateRex(b.groupKey(),
3410+
b.alias(b.call(SqlStdOperatorTable.SUM, b.field("EMPNO")), "s"),
3411+
b.alias(b.call(SqlStdOperatorTable.COUNT, b.field("SAL")), "c"))
3412+
.build();
3413+
final String expected =
3414+
"LogicalAggregate(group=[{}], s=[SUM($0)], c=[COUNT($5)])\n"
3415+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
3416+
// s is nullable because "GROUP BY ()" may have a group that contains 0 rows
3417+
final String expectedRowType =
3418+
"RecordType(SMALLINT s, BIGINT NOT NULL c) NOT NULL";
3419+
final RelNode r = f.apply(createBuilder());
3420+
assertThat(r, hasTree(expected));
3421+
assertThat(r.getRowType().getFullTypeString(), is(expectedRowType));
3422+
}
3423+
34013424
/** Tests that a projection retains field names after a join. */
34023425
@Test void testProjectJoin() {
34033426
final RelBuilder builder = RelBuilder.create(config().build());

core/src/test/resources/sql/measure.iq

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,27 @@ group by job;
8484

8585
!ok
8686

87+
# Measure on primary key gives type error (casting away NOT NULL); cause was
88+
# [CALCITE-6555] RelBuilder.aggregateRex thinks aggregate functions of
89+
# "GROUP BY ()" queries are NOT NULL
90+
with empm as (
91+
select *, min(empno) as measure avg_sal
92+
from emp
93+
)
94+
select deptno, avg_sal as a
95+
from empm
96+
group by deptno;
97+
+--------+------+
98+
| DEPTNO | A |
99+
+--------+------+
100+
| 10 | 7782 |
101+
| 20 | 7369 |
102+
| 30 | 7499 |
103+
+--------+------+
104+
(3 rows)
105+
106+
!ok
107+
87108
# Equivalent using AGGREGATE
88109
select job, aggregate(avg_sal) as a
89110
from empm

0 commit comments

Comments
 (0)