public void onMatch()

in src/query/src/main/java/org/apache/kylin/query/optrule/OlapAggSumCastRule.java [85:171]


    /**
     * Rewrites {@code SUM(CAST(x AS DOUBLE))}, where {@code x} has a NUMERIC-family type, into a SUM over the
     * uncast operand {@code x}. The rewritten call's return type is re-inferred from {@code x}'s type.
     *
     * <p>The rule reads the matched {@link Aggregate} as {@code rel(0)} and a {@link Project} as {@code rel(1)}.
     * The Project is assumed to be the Aggregate's direct input, because aggregate argument indexes are resolved
     * against {@code oldProject.getProjects()}. The replacement tree is built as follows:
     * <ol>
     *   <li>a bottom project over the Project's input;</li>
     *   <li>the aggregate with the rewritten calls;</li>
     *   <li>a top project produced by {@code buildTopProject} (defined elsewhere in this class; presumably it
     *   restores the original output types of the rewritten calls).</li>
     * </ol>
     *
     * <p>Where the uncast operand is placed depends on how column {@code i} is used:
     * <ul>
     *   <li>If column {@code i} is used only by the matched SUM, the cast is dropped in place.</li>
     *   <li>If it is also a group key, or is referenced by any other aggregate argument or filter, the column is
     *   kept unchanged. The uncast operand is then appended as an extra bottom-project column.</li>
     * </ul>
     *
     * @param ruleCall the rule call holding the matched Aggregate and Project
     */
    public void onMatch(RelOptRuleCall ruleCall) {
        Aggregate oldAgg = ruleCall.rel(0);
        Project oldProject = ruleCall.rel(1);
        List<AggregateCall> aggCallList = oldAgg.getAggCallList();

        // Single-argument SUM calls keyed by the input column they aggregate. If several SUMs share a column,
        // only the last one is kept here; the others are left untouched.
        Map<Integer, AggregateCall> sumMatchMap = new HashMap<>();
        boolean hasAggSum = false;
        for (AggregateCall aggregateCall : aggCallList) {
            if (aggregateCall.getAggregation().getKind() == SqlKind.SUM) {
                hasAggSum = true;
                List<Integer> argList = aggregateCall.getArgList();
                if (argList.size() == 1) {
                    sumMatchMap.put(argList.get(0), aggregateCall);
                }
            }
        }
        if (!hasAggSum) {
            return;
        }

        Map<Integer, Integer> columnRefCounts = countAggColumnReferences(aggCallList);
        Set<Integer> groupBySet = oldAgg.getGroupSet().asSet();
        RelBuilder relBuilder = ruleCall.builder();
        RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();
        Map<AggregateCall, AggregateCall> rewriteAggCallMap = new HashMap<>();
        List<RexNode> exprList = oldProject.getProjects();
        List<RexNode> bottomProjectRexNodes = new ArrayList<>(exprList.size());
        // Extra columns appended after the original ones; indexes start at exprList.size().
        List<RexNode> rewriteProjectRexNodes = new ArrayList<>();
        for (int i = 0; i < exprList.size(); i++) {
            RexNode rexNode = exprList.get(i);
            AggregateCall aggregateCall = sumMatchMap.get(i);
            RexNode castOperand = aggregateCall == null ? null : numericOperandOfDoubleCast(rexNode);
            if (castOperand == null) {
                bottomProjectRexNodes.add(rexNode);
                continue;
            }
            RelDataType returnDataType = aggregateCall.getAggregation().inferReturnType(typeFactory,
                    Collections.singletonList(castOperand.getType()));
            List<Integer> newArgList;
            if (groupBySet.contains(i) || columnRefCounts.getOrDefault(i, 0) > 1) {
                // Column i is still needed with its original type by a group key or another aggregate
                // argument/filter. Keep it and feed the rewritten SUM from an appended uncast column.
                newArgList = Collections.singletonList(exprList.size() + rewriteProjectRexNodes.size());
                rewriteProjectRexNodes.add(castOperand);
                bottomProjectRexNodes.add(rexNode);
            } else {
                // The matched SUM is the only consumer of column i, so the cast can be dropped in place.
                newArgList = aggregateCall.getArgList();
                bottomProjectRexNodes.add(castOperand);
            }
            // Preserve DISTINCT, APPROXIMATE and FILTER of the original call. Only the argument and the return
            // type change. The filter index stays valid because original columns keep their positions.
            AggregateCall newAggCall = AggregateCall.create(aggregateCall.getAggregation(),
                    aggregateCall.isDistinct(), aggregateCall.isApproximate(), newArgList, aggregateCall.filterArg,
                    returnDataType, aggregateCall.getName());
            rewriteAggCallMap.put(aggregateCall, newAggCall);
        }

        if (rewriteAggCallMap.isEmpty()) {
            return;
        }

        bottomProjectRexNodes.addAll(rewriteProjectRexNodes);
        relBuilder.push(oldProject.getInput());
        relBuilder.project(bottomProjectRexNodes);
        List<AggregateCall> newAggregateCallList = new ArrayList<>(aggCallList.size());
        for (AggregateCall aggCall : aggCallList) {
            newAggregateCallList.add(rewriteAggCallMap.getOrDefault(aggCall, aggCall));
        }
        RelBuilder.GroupKey groupKey = oldAgg.getGroupSets() == null ? relBuilder.groupKey(oldAgg.getGroupSet())
                : relBuilder.groupKey(oldAgg.getGroupSet(), oldAgg.getGroupSets());
        relBuilder.aggregate(groupKey, newAggregateCallList);
        List<RexNode> topProjList = buildTopProject(relBuilder, oldAgg, rewriteAggCallMap);
        relBuilder.project(topProjList);
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Counts how many argument and filter references the given aggregate calls make to each input column.
     *
     * @param aggCalls the aggregate calls to scan
     * @return a map from input column index to reference count; unreferenced columns are absent
     */
    private static Map<Integer, Integer> countAggColumnReferences(List<AggregateCall> aggCalls) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (AggregateCall aggCall : aggCalls) {
            for (Integer arg : aggCall.getArgList()) {
                counts.merge(arg, 1, Integer::sum);
            }
            if (aggCall.filterArg >= 0) {
                counts.merge(aggCall.filterArg, 1, Integer::sum);
            }
        }
        return counts;
    }

    /**
     * Returns the operand of {@code rexNode} when it is a single-operand CAST whose result type is DOUBLE and
     * whose operand has a NUMERIC-family type; otherwise returns {@code null}.
     *
     * @param rexNode the project expression to inspect
     * @return the uncast numeric operand, or {@code null} if {@code rexNode} does not match
     */
    private static RexNode numericOperandOfDoubleCast(RexNode rexNode) {
        if (!(rexNode instanceof RexCall) || !(((RexCall) rexNode).op instanceof SqlCastFunction)) {
            return null;
        }
        RexCall rexCall = (RexCall) rexNode;
        List<RexNode> operands = rexCall.getOperands();
        if (operands.size() != 1) {
            return null;
        }
        RexNode operand = operands.get(0);
        boolean castsToDouble = SqlTypeName.DOUBLE == rexCall.getType().getSqlTypeName();
        boolean fromNumeric = SqlTypeFamily.NUMERIC == operand.getType().getSqlTypeName().getFamily();
        return castsToDouble && fromNumeric ? operand : null;
    }