private RexNode reduceStddev()

in exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java [490:662]


  /**
   * Rewrites a standard-deviation / variance style aggregate call as an
   * arithmetic expression over SUM and COUNT aggregate calls.
   *
   * <p>The expression that is built has the following shape, where {@code x}
   * is the (CastHighOp-wrapped) aggregate argument:
   *
   * <pre>
   *   biased:
   *     (sum(x * x) - sum(x) * sum(x) / count(x)) / count(x)
   *
   *   not biased:
   *     (sum(x * x) - sum(x) * sum(x) / count(x))
   *       / case when count(x) = 1 then null else count(x) - 1 end
   * </pre>
   *
   * <p>When {@code sqrt} is set, the quotient is additionally wrapped in
   * {@code power(..., 0.5)}.
   *
   * @param oldAggRel      aggregate relational expression that contains {@code oldCall}
   * @param oldCall        aggregate call being rewritten; must have exactly one argument
   * @param biased         if true the denominator is {@code count(x)}; otherwise it is
   *                       {@code count(x) - 1}, replaced by NULL when {@code count(x) = 1}
   * @param sqrt           if true the result is raised to the power 0.5
   * @param newCalls       list handed to {@code RexBuilder.addAggCall} for the generated
   *                       SUM and COUNT calls
   * @param aggCallMapping mapping handed to {@code RexBuilder.addAggCall} together with
   *                       {@code newCalls}
   * @param inputExprs     input expressions; the entry at the argument ordinal is replaced
   *                       by its CastHighOp-wrapped form and the squared argument is
   *                       looked up or added through {@code lookupOrAdd}
   * @return the rewritten expression; when type inference is disabled it is wrapped in a
   *         cast to ANY
   */
  private RexNode reduceStddev(
      Aggregate oldAggRel,
      AggregateCall oldCall,
      boolean biased,
      boolean sqrt,
      List<AggregateCall> newCalls,
      Map<AggregateCall, RexNode> aggCallMapping,
      List<RexNode> inputExprs) {
    final boolean isInferenceEnabled =
        ((PlannerSettings) oldAggRel.getCluster().getPlanner().getContext())
            .isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

    // The aggregated input is swapped in place for its CastHighOp-wrapped form,
    // so both sum(x) and sum(x * x) below operate on the wrapped expression.
    final RexNode wrappedArg =
        rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, wrappedArg);

    // x * x, registered as an input expression of its own.
    final int squaredArgOrdinal =
        lookupOrAdd(
            inputExprs,
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, wrappedArg, wrappedArg));

    // Return type shared by both SUM calls, forced to be nullable.
    final RelDataType sumType =
        typeFactory.createTypeWithNullability(
            TypeInferenceUtils
                .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
                .inferReturnType(oldCall.createBinding(oldAggRel)),
            true);

    // sum(x * x)
    final AggregateCall sumOfSquaresCall =
        AggregateCall.create(
            new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType),
            oldCall.isDistinct(),
            oldCall.isApproximate(),
            oldCall.ignoreNulls(),
            ImmutableIntList.of(squaredArgOrdinal),
            oldCall.filterArg,
            oldCall.distinctKeys,
            oldCall.getCollation(),
            sumType,
            null);
    // Same argument-type list is passed for every generated aggregate call.
    final List<RelDataType> aggArgTypes = ImmutableList.of(argType);
    final RexNode sumOfSquares =
        rexBuilder.addAggCall(
            sumOfSquaresCall, groupCount, newCalls, aggCallMapping, aggArgTypes);

    // sum(x)
    final AggregateCall plainSumCall =
        AggregateCall.create(
            new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType),
            oldCall.isDistinct(),
            oldCall.isApproximate(),
            oldCall.ignoreNulls(),
            ImmutableIntList.of(argOrdinal),
            oldCall.filterArg,
            oldCall.distinctKeys,
            oldCall.getCollation(),
            sumType,
            null);
    final RexNode plainSum =
        rexBuilder.addAggCall(
            plainSumCall, groupCount, newCalls, aggCallMapping, aggArgTypes);

    // sum(x) * sum(x)
    final RexNode squareOfSum =
        rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, plainSum, plainSum);

    // count(x)
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RexNode count =
        rexBuilder.addAggCall(
            getAggCall(oldCall, countFunction, countFunction.getReturnType(typeFactory)),
            groupCount,
            newCalls,
            aggCallMapping,
            aggArgTypes);

    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode numerator =
        rexBuilder.makeCall(
            SqlStdOperatorTable.MINUS,
            sumOfSquares,
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count));

    final RexNode denominator;
    if (!biased) {
      // case when count(x) = 1 then null else count(x) - 1 end
      final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
      denominator =
          rexBuilder.makeCall(
              SqlStdOperatorTable.CASE,
              rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one),
              rexBuilder.makeNullLiteral(count.getType()),
              rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one));
    } else {
      denominator = count;
    }

    // With type inference on, a Drill "divide" operator carrying the old call's
    // type is used; otherwise the standard DIVIDE operator.
    final SqlOperator divide = isInferenceEnabled
        ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
        : SqlStdOperatorTable.DIVIDE;
    final RexNode quotient = rexBuilder.makeCall(divide, numerator, denominator);

    final RexNode result = sqrt
        ? rexBuilder.makeCall(
            SqlStdOperatorTable.POWER,
            quotient,
            rexBuilder.makeExactLiteral(new BigDecimal("0.5")))
        : quotient;

    if (isInferenceEnabled) {
      return result;
    }

    /*
     * Calcite's current strategy for inferring the return type of aggregate
     * functions is wrong: it derives the output type from the first known
     * argument. For example, stddev over an integer column is interpreted as
     * returning integer, whereas it should be double. Casting the rewritten
     * aggregate to that type would therefore introduce an extra cast and give
     * wrong results, so the expression is simply cast to ANY instead.
     */
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), result);
  }