in integration/spark/src/main/scala/org/apache/spark/sql/optimizer/MVRewrite.scala [922:992]
/**
 * Derives the rewritten output list and predicate list for `groupBy` from `outputListMapping`.
 *
 * Each mapping entry is a triple `(output1, output2, output3)`:
 *  - `output1` decides which rewrite is applied and supplies the name and expression id of
 *    the resulting expression;
 *  - `output2` is the expression substituted in (as the aggregate's child, or as-is);
 *  - `output3` is read only by the `Average` branch, as the operand of the second `Sum`.
 * NOTE(review): presumably `output1` comes from the user query and `output2`/`output3` from
 * the materialized view output - confirm against the callers.
 *
 * Rewrites applied when `output1` is an `Alias` over an `AggregateExpression`:
 *  - `Sum`, `Max`, `Min`: same function, with its child replaced by `output2`;
 *  - `Average`: `Sum(output2) / Sum(output3)`, both sides cast to `DoubleType`;
 *  - `Count`, `Corr`, `VariancePop`, `VarianceSamp`, `StddevSamp`, `StddevPop`, `CovSample`,
 *    `Skewness`, `Kurtosis`, `CovPopulation`: replaced by `Sum(output2)`.
 * Any other `output1` yields `output2`, re-aliased to `output1`'s name when the names differ.
 *
 * @param groupBy           node whose `predicateList` is rewritten
 * @param outputListMapping triples as described above
 * @return the rewritten outputs (one per mapping entry, in order) and the rewritten predicates
 *         (one per entry of `groupBy.predicateList`, in order)
 * @throws NoSuchElementException if an `Average` entry has `output3` equal to `None`
 */
private def getUpdatedOutputAndPredicateList(groupBy: GroupBy,
    outputListMapping: Seq[(NamedExpression, NamedExpression, Option[NamedExpression])]):
(Seq[NamedExpression], Seq[Expression]) = {
  val rewrittenOutputs = outputListMapping.map { case (output1, output2, output3) =>
    // Wraps the rewritten expression so it keeps output1's name and expression id.
    def aliased(expr: Expression): Alias =
      Alias(expr, output1.name)(exprId = output1.exprId)

    // Fallback for anything that is not one of the recognised aggregates: use output2
    // directly, adding an alias only when its name differs from output1's.
    def passThrough: NamedExpression =
      if (output1.name == output2.name) output2 else aliased(output2)

    output1 match {
      case Alias(aggExpr: AggregateExpression, _) =>
        aggExpr.aggregateFunction match {
          case sum: Sum =>
            aliased(aggExpr.copy(aggregateFunction = sum.copy(child = output2)))
          case _: Average =>
            // The average is rebuilt as a ratio of two sums; unlike the other branches the
            // result is a plain expression and is not wrapped back into aggExpr.
            val total = Sum(output2)
            val rowCount = Sum(output3.get)
            val quotient = Divide(Cast(total, DoubleType), Cast(rowCount, DoubleType))
            aliased(Cast(quotient, DoubleType))
          case max: Max =>
            aliased(aggExpr.copy(aggregateFunction = max.copy(child = output2)))
          case min: Min =>
            aliased(aggExpr.copy(aggregateFunction = min.copy(child = output2)))
          case _: Count | _: Corr | _: VariancePop | _: VarianceSamp | _: StddevSamp |
               _: StddevPop | _: CovSample | _: Skewness | _: Kurtosis | _: CovPopulation =>
            aliased(aggExpr.copy(aggregateFunction = Sum(output2)))
          case _ =>
            passThrough
        }
      case _ =>
        passThrough
    }
  }

  val rewrittenPredicates = groupBy.predicateList.map { predicate =>
    // Aliased outputs are compared by their child; when the predicate is itself an Alias,
    // the comparison is made against the predicate's first child instead of the predicate.
    val aliasTarget = predicate match {
      case _: Alias => predicate.children.head
      case _ => predicate
    }
    outputListMapping
      .find {
        case (alias: Alias, _, _) => alias.child.semanticEquals(aliasTarget)
        case (other, _, _) => other.semanticEquals(predicate)
      }
      .map { case (_, replacement, _) => replacement }
      // Predicates with no semantically equal output are kept unchanged.
      .getOrElse(predicate)
  }

  (rewrittenOutputs, rewrittenPredicates)
}