|
150 | 150 | import java.util.Set;
|
151 | 151 | import java.util.SortedSet;
|
152 | 152 | import java.util.TreeSet;
|
| 153 | +import java.util.concurrent.atomic.AtomicInteger; |
153 | 154 | import java.util.function.BiFunction;
|
154 | 155 | import java.util.function.Consumer;
|
155 | 156 | import java.util.function.Function;
|
@@ -2632,14 +2633,22 @@ public RelBuilder aggregateRex(GroupKey groupKey, boolean projectKey,
|
2632 | 2633 | Iterable<? extends RexNode> nodes) {
|
2633 | 2634 | final GroupKeyImpl groupKeyImpl = (GroupKeyImpl) groupKey;
|
2634 | 2635 | final AggBuilder aggBuilder = new AggBuilder(groupKeyImpl.nodes);
|
2635 |
| - for (RexNode node : nodes) { |
2636 |
| - aggBuilder.add(node); |
| 2636 | + |
| 2637 | + // First pass. Call convert on each expression to ensure that aggCalls |
| 2638 | + // gets populated. |
| 2639 | + aggBuilder.registerExpressions(nodes); |
| 2640 | + |
| 2641 | + // Create the Aggregate on the stack. |
| 2642 | + aggregate(groupKey, aggBuilder.aggCalls); |
| 2643 | + |
| 2644 | + // Second pass. Call convert on each expression so that it references the |
| 2645 | + // actual aggCalls in the Aggregate that was just pushed onto the stack. |
| 2646 | + final List<RexNode> projects = new ArrayList<>(); |
| 2647 | + if (projectKey) { |
| 2648 | + projects.addAll(fields(Util.range(groupKey.groupKeyCount()))); |
2637 | 2649 | }
|
2638 |
| - return aggregate(groupKey, aggBuilder.aggCalls) |
2639 |
| - .project( |
2640 |
| - Iterables.concat( |
2641 |
| - fields(Util.range(projectKey ? groupKey.groupKeyCount() : 0)), |
2642 |
| - aggBuilder.postProjects)); |
| 2650 | + aggBuilder.convertExpressions(projects::add, nodes); |
| 2651 | + return project(projects); |
2643 | 2652 | }
|
2644 | 2653 |
|
2645 | 2654 | /** Finishes the implementation of {@link #aggregate} by creating an
|
@@ -5040,46 +5049,82 @@ default boolean removeRedundantDistinct() {
|
5040 | 5049 | /** Working state for {@link #aggregateRex}. */
|
5041 | 5050 | private class AggBuilder {
|
5042 | 5051 | final ImmutableList<RexNode> groupKeys;
|
5043 |
| - final List<RexNode> postProjects = new ArrayList<>(); |
5044 | 5052 | final List<AggCall> aggCalls = new ArrayList<>();
|
5045 | 5053 |
|
5046 | 5054 | private AggBuilder(ImmutableList<RexNode> groupKeys) {
|
5047 | 5055 | this.groupKeys = groupKeys;
|
5048 | 5056 | }
|
5049 | 5057 |
|
5050 |
| - /** Adds a node that may or may not contain an aggregate function. */ |
5051 |
| - void add(RexNode node) { |
5052 |
| - postProjects.add(convert(node)); |
5053 |
| - } |
5054 |
| - |
5055 | 5058 | /** Adds a node that we know to contain an aggregate function, and returns
|
5056 | 5059 | * an expression whose input row type is the output row type of the
|
5057 | 5060 | * aggregate layer ({@link #groupKeys} and {@link #aggCalls}). */
|
5058 |
| - private RexNode convert(RexNode node) { |
5059 |
| - final RexBuilder rexBuilder = cluster.getRexBuilder(); |
5060 |
| - if (node instanceof RexCall) { |
5061 |
| - final RexCall call = (RexCall) node; |
5062 |
| - if (call.getOperator().isAggregator()) { |
5063 |
| - final AggCall aggCall = |
5064 |
| - aggregateCall((SqlAggFunction) call.op, call.operands); |
5065 |
| - final int i = groupKeys.size() + aggCalls.size(); |
5066 |
| - aggCalls.add(aggCall); |
5067 |
| - return rexBuilder.makeInputRef(call.getType(), i); |
| 5061 | + private RexNode convert(RegisterAgg registrar, RexNode node, |
| 5062 | + @Nullable String name) { |
| 5063 | + switch (node.getKind()) { |
| 5064 | + case AS: |
| 5065 | + final ImmutableList<RexNode> asOperands = ((RexCall) node).operands; |
| 5066 | + final String name2; |
| 5067 | + if (name != null) { |
| 5068 | + name2 = name; |
5068 | 5069 | } else {
|
5069 |
| - final List<RexNode> operands = new ArrayList<>(); |
5070 |
| - call.operands.forEach(operand -> |
5071 |
| - operands.add(convert(operand))); |
5072 |
| - return call.clone(call.type, operands); |
| 5070 | + final RexLiteral literal = (RexLiteral) asOperands.get(1); |
| 5071 | + name2 = requireNonNull(literal.getValueAs(String.class)); |
5073 | 5072 | }
|
5074 |
| - } else if (node instanceof RexInputRef) { |
| 5073 | + final RexNode node2 = convert(registrar, asOperands.get(0), name2); |
| 5074 | + return alias(node2, name2); |
| 5075 | + |
| 5076 | + case INPUT_REF: |
5075 | 5077 | final int j = groupKeys.indexOf(node);
|
5076 | 5078 | if (j < 0) {
|
5077 | 5079 | throw new IllegalArgumentException("not a group key: " + node);
|
5078 | 5080 | }
|
5079 |
| - return rexBuilder.makeInputRef(node.getType(), j); |
5080 |
| - } else { |
| 5081 | + return field(j); |
| 5082 | + |
| 5083 | + default: |
| 5084 | + if (node instanceof RexCall) { |
| 5085 | + final RexCall call = (RexCall) node; |
| 5086 | + if (call.getOperator().isAggregator()) { |
| 5087 | + // return a reference to the i'th agg call |
| 5088 | + return registrar.registerAgg((SqlAggFunction) call.op, |
| 5089 | + call.operands, call.type, name); |
| 5090 | + } else { |
| 5091 | + return call.clone(call.type, |
| 5092 | + Util.transform(call.operands, operand -> |
| 5093 | + convert(registrar, operand, null))); |
| 5094 | + } |
| 5095 | + } |
5081 | 5096 | return node;
|
5082 | 5097 | }
|
5083 | 5098 | }
|
| 5099 | + |
| 5100 | + void registerExpressions(Iterable<? extends RexNode> nodes) { |
| 5101 | + for (RexNode node : nodes) { |
| 5102 | + convert(this::registerAgg, node, null); |
| 5103 | + } |
| 5104 | + } |
| 5105 | + |
| 5106 | + RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands, |
| 5107 | + RelDataType type, @Nullable String name) { |
| 5108 | + final int i = groupKeys.size() + aggCalls.size(); |
| 5109 | + aggCalls.add(aggregateCall(op, operands).as(name)); |
| 5110 | + return getRexBuilder().makeInputRef(type, i); |
| 5111 | + } |
| 5112 | + |
| 5113 | + void convertExpressions(Consumer<RexNode> projects, |
| 5114 | + Iterable<? extends RexNode> nodes) { |
| 5115 | + final AtomicInteger j = new AtomicInteger(groupKeys.size()); |
| 5116 | + for (RexNode node : nodes) { |
| 5117 | + projects.accept( |
| 5118 | + convert((op, operands, type, name) -> field(j.getAndIncrement()), |
| 5119 | + node, null)); |
| 5120 | + } |
| 5121 | + } |
| 5122 | + } |
| 5123 | + |
| 5124 | + /** Callback to handle creation of an aggregate call in |
| 5125 | + * {@link AggBuilder#convert}. */ |
| 5126 | + private interface RegisterAgg { |
| 5127 | + RexInputRef registerAgg(SqlAggFunction op, List<RexNode> operands, |
| 5128 | + RelDataType type, @Nullable String name); |
5084 | 5129 | }
|
5085 | 5130 | }
|
0 commit comments