in src/query/src/main/java/org/apache/kylin/query/optrule/OlapAggSumCastRule.java [85:171]
public void onMatch(RelOptRuleCall ruleCall) {
    // Rewrites SUM(CAST(x AS DOUBLE)) over a numeric x into a SUM over x itself. The old->new
    // aggregate-call mapping is handed to buildTopProject, which builds the projection placed on
    // top of the new aggregate.
    Aggregate oldAgg = ruleCall.rel(0);
    Project oldProject = ruleCall.rel(1);
    List<AggregateCall> aggCallList = oldAgg.getAggCallList();

    // SUM calls eligible for rewriting, keyed by their single input column. One column may feed
    // several SUMs (e.g. SUM(c) and SUM(DISTINCT c)); all of them must be rewritten together,
    // otherwise a non-rewritten one would end up typed DOUBLE over an uncast input.
    Map<Integer, List<AggregateCall>> sumCallsByInput = new HashMap<>();
    // Input columns read by something other than an eligible SUM (group keys, arguments of other
    // aggregate calls, FILTER arguments). Their projected expression must stay unchanged.
    Set<Integer> sharedInputs = Collections.newSetFromMap(new HashMap<>());
    sharedInputs.addAll(oldAgg.getGroupSet().asSet());
    for (AggregateCall aggCall : aggCallList) {
        List<Integer> argList = aggCall.getArgList();
        // The AggregateCall constructor used below cannot carry a FILTER argument, so filtered
        // SUMs are left alone; rewriting them would silently drop the filter.
        boolean rewritableSum = aggCall.getAggregation().getKind() == SqlKind.SUM
                && argList.size() == 1 && aggCall.filterArg < 0;
        if (rewritableSum) {
            sumCallsByInput.computeIfAbsent(argList.get(0), k -> new ArrayList<>()).add(aggCall);
        } else {
            sharedInputs.addAll(argList);
        }
        if (aggCall.filterArg >= 0) {
            sharedInputs.add(aggCall.filterArg);
        }
    }
    if (sumCallsByInput.isEmpty()) {
        return;
    }

    RelBuilder relBuilder = ruleCall.builder();
    RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();
    List<RexNode> exprList = oldProject.getProjects();
    List<RexNode> bottomProjectRexNodes = new ArrayList<>(exprList);
    List<RexNode> appendedRexNodes = new ArrayList<>();
    Map<AggregateCall, AggregateCall> rewriteAggCallMap = new HashMap<>();
    for (int i = 0; i < exprList.size(); i++) {
        List<AggregateCall> sumCalls = sumCallsByInput.get(i);
        if (sumCalls == null) {
            continue;
        }
        RexNode uncastOperand = numericOperandOfDoubleCast(exprList.get(i));
        if (uncastOperand == null) {
            continue;
        }
        int newArg;
        if (sharedInputs.contains(i)) {
            // Other consumers still need the DOUBLE value: append the uncast operand as a new
            // trailing column so existing column indexes (group keys, other args) stay valid.
            newArg = exprList.size() + appendedRexNodes.size();
            appendedRexNodes.add(uncastOperand);
        } else {
            // Only eligible SUMs read this column, so the cast can be dropped in place.
            newArg = i;
            bottomProjectRexNodes.set(i, uncastOperand);
        }
        for (AggregateCall sumCall : sumCalls) {
            RelDataType returnDataType = sumCall.getAggregation().inferReturnType(typeFactory,
                    Collections.singletonList(uncastOperand.getType()));
            // Preserve DISTINCT: SUM(DISTINCT ..) must not degrade into a plain SUM.
            AggregateCall newAggCall = new AggregateCall(sumCall.getAggregation(), sumCall.isDistinct(),
                    Collections.singletonList(newArg), returnDataType, sumCall.getName());
            rewriteAggCallMap.put(sumCall, newAggCall);
        }
    }
    if (rewriteAggCallMap.isEmpty()) {
        return;
    }
    bottomProjectRexNodes.addAll(appendedRexNodes);

    relBuilder.push(oldProject.getInput());
    relBuilder.project(bottomProjectRexNodes);
    List<AggregateCall> newAggregateCallList = new ArrayList<>(aggCallList.size());
    for (AggregateCall aggCall : aggCallList) {
        newAggregateCallList.add(rewriteAggCallMap.getOrDefault(aggCall, aggCall));
    }
    RelBuilder.GroupKey groupKey = oldAgg.getGroupSets() == null ? relBuilder.groupKey(oldAgg.getGroupSet())
            : relBuilder.groupKey(oldAgg.getGroupSet(), oldAgg.getGroupSets());
    relBuilder.aggregate(groupKey, newAggregateCallList);
    List<RexNode> topProjList = buildTopProject(relBuilder, oldAgg, rewriteAggCallMap);
    relBuilder.project(topProjList);
    ruleCall.transformTo(relBuilder.build());
}

/**
 * Returns the operand of {@code rexNode} when it is a single-operand CAST whose result type is
 * DOUBLE and whose operand's type belongs to the NUMERIC family; otherwise returns {@code null}.
 *
 * @param rexNode a projected expression feeding a SUM
 * @return the uncast numeric operand, or {@code null} if the expression is not such a cast
 */
private static RexNode numericOperandOfDoubleCast(RexNode rexNode) {
    if (!(rexNode instanceof RexCall)) {
        return null;
    }
    RexCall rexCall = (RexCall) rexNode;
    if (!(rexCall.op instanceof SqlCastFunction) || rexCall.getOperands().size() != 1) {
        return null;
    }
    RexNode operand = rexCall.getOperands().get(0);
    boolean castsNumericToDouble = SqlTypeName.DOUBLE == rexCall.getType().getSqlTypeName()
            && SqlTypeFamily.NUMERIC == operand.getType().getSqlTypeName().getFamily();
    return castsNumericToDouble ? operand : null;
}