in exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java [490:662]
/**
 * Rewrites a single-argument standard-deviation style aggregate call as an
 * arithmetic expression over SUM and COUNT aggregates.
 *
 * <p>The generated expression has the form
 * <pre>
 *   (sum(x * x) - sum(x) * sum(x) / count(x)) / d
 * </pre>
 * where {@code d} is {@code count(x)} when {@code biased} is true, and
 * {@code CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END} otherwise.
 * When {@code sqrt} is true the quotient is additionally wrapped in
 * {@code POWER(..., 0.5)}.
 *
 * @param oldAggRel      aggregate being rewritten; supplies the cluster (planner settings,
 *                       type factory, rex builder), the group count and the input row type
 * @param oldCall        call being replaced; must have exactly one argument. Its distinct,
 *                       approximate, ignore-nulls, filter, distinct-keys and collation
 *                       settings are copied onto the generated SUM calls
 * @param biased         true to divide by {@code count(x)}; false to divide by
 *                       {@code count(x) - 1} (NULL when the count is 1)
 * @param sqrt           true to take the square root of the quotient; false to return
 *                       the quotient itself (the variance form)
 * @param newCalls       list handed to {@code RexBuilder.addAggCall} for the three
 *                       replacement aggregate calls (sum of squares, sum, count - in that order)
 * @param aggCallMapping mapping handed to {@code RexBuilder.addAggCall} together with
 *                       {@code newCalls}
 * @param inputExprs     input expressions; modified in place: the argument's entry is replaced
 *                       by its {@code CastHighOp}-wrapped form and the squared argument is
 *                       resolved through {@code lookupOrAdd}
 * @return the replacement expression; cast to {@code ANY} when type inference is disabled
 */
private RexNode reduceStddev(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    boolean biased,
    boolean sqrt,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping,
    List<RexNode> inputExprs) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rex = oldAggRel.getCluster().getRexBuilder();

  assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
  final int inputIdx = oldCall.getArgList().get(0);
  final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputIdx);
  final List<RelDataType> inputTypes = ImmutableList.of(inputType);

  // The raw input expression is not used directly: it is wrapped in CastHighOp
  // (presumably a widening cast - see its definition) and the wrapped form is
  // written back so that later references to this ordinal see it too.
  final RexNode widened = rex.makeCall(CastHighOp, inputExprs.get(inputIdx));
  inputExprs.set(inputIdx, widened);

  // x * x, registered as an input expression so SUM can reference it by ordinal.
  final RexNode squared = rex.makeCall(SqlStdOperatorTable.MULTIPLY, widened, widened);
  final int squaredIdx = lookupOrAdd(inputExprs, squared);

  // Both SUM calls share one return type: Drill's inferred SUM type, forced nullable.
  final RelDataType inferredSumType =
      TypeInferenceUtils
          .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
          .inferReturnType(oldCall.createBinding(oldAggRel));
  final RelDataType sumType = types.createTypeWithNullability(inferredSumType, true);

  // sum(x * x)
  final AggregateCall sumOfSquaresCall =
      AggregateCall.create(
          new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType),
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.ignoreNulls(),
          ImmutableIntList.of(squaredIdx),
          oldCall.filterArg,
          oldCall.distinctKeys,
          oldCall.getCollation(),
          sumType,
          null);
  final RexNode sumOfSquares =
      rex.addAggCall(sumOfSquaresCall, groupCount, newCalls, aggCallMapping, inputTypes);

  // sum(x)
  final AggregateCall plainSumCall =
      AggregateCall.create(
          new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType),
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.ignoreNulls(),
          ImmutableIntList.of(inputIdx),
          oldCall.filterArg,
          oldCall.distinctKeys,
          oldCall.getCollation(),
          sumType,
          null);
  final RexNode plainSum =
      rex.addAggCall(plainSumCall, groupCount, newCalls, aggCallMapping, inputTypes);

  // sum(x) * sum(x)
  final RexNode squareOfSum = rex.makeCall(SqlStdOperatorTable.MULTIPLY, plainSum, plainSum);

  // count(x)
  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(types);
  final AggregateCall countCall = getAggCall(oldCall, countFunction, countType);
  final RexNode count =
      rex.addAggCall(countCall, groupCount, newCalls, aggCallMapping, inputTypes);

  // sum(x * x) - sum(x) * sum(x) / count(x)
  final RexNode squareOfSumOverCount =
      rex.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count);
  final RexNode numerator =
      rex.makeCall(SqlStdOperatorTable.MINUS, sumOfSquares, squareOfSumOverCount);

  final RexNode denominator;
  if (!biased) {
    // Sample form: divide by (count - 1), yielding NULL instead of a zero divisor
    // when exactly one value was aggregated.
    final RexLiteral one = rex.makeExactLiteral(BigDecimal.ONE);
    final RexNode nullOfCountType = rex.makeNullLiteral(count.getType());
    final RexNode countLessOne = rex.makeCall(SqlStdOperatorTable.MINUS, count, one);
    final RexNode countIsOne = rex.makeCall(SqlStdOperatorTable.EQUALS, count, one);
    denominator =
        rex.makeCall(SqlStdOperatorTable.CASE, countIsOne, nullOfCountType, countLessOne);
  } else {
    // Population form: divide by the plain count.
    denominator = count;
  }

  // With type inference on, a Drill "divide" operator carrying the original call's
  // type is used; otherwise the standard DIVIDE operator.
  final SqlOperator divisionOp = inferTypes
      ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
      : SqlStdOperatorTable.DIVIDE;
  final RexNode quotient = rex.makeCall(divisionOp, numerator, denominator);

  final RexNode reduced = sqrt
      ? rex.makeCall(
          SqlStdOperatorTable.POWER, quotient, rex.makeExactLiteral(new BigDecimal("0.5")))
      : quotient;

  if (inferTypes) {
    return reduced;
  }

  /*
   * Without Drill's type inference, Calcite derives the aggregate's return type from
   * the first known argument. For stddev over an integer column that means an integer
   * result type, although the value should be a double. Casting the rewritten
   * expression back to that type would therefore produce wrong results, so the
   * expression is cast to ANY instead.
   */
  return rex.makeCast(types.createSqlType(SqlTypeName.ANY), reduced);
}