in integration/spark/src/main/scala/org/apache/spark/sql/optimizer/MVRewrite.scala [734:920]
/**
 * Rewrites one node of a modular plan so that it reads from the MV plan stored in the
 * node's `modularPlan` field (an `MVPlanWrapper` wrapping a `Select`). Rewritten nodes are
 * marked with `setRewritten()`, except in case 4 below.
 *
 * Cases, tried in this order:
 *  1. A Select whose own `modularPlan` is set: the MV Select replaces it. MV output columns
 *     are re-aliased to the query's column names and exprIds, and the query's flags are kept.
 *  2. A Select whose only child is a GroupBy with `modularPlan` set:
 *     - if the MV schema is not refreshed incrementally, the GroupBy is dropped and the MV
 *       Select becomes the direct child of the Select;
 *     - otherwise the GroupBy is rewritten by a recursive call (case 3), and each Select
 *       output is taken from the rewritten child at the position of the GroupBy output it
 *       semantically matches.
 *     Any other Select is returned unchanged.
 *  3. A GroupBy with `modularPlan` set: its child becomes the MV Select. An avg in the
 *     query is mapped to the two MV columns (sum and count) that hold it.
 *  4. A GroupBy with predicates whose child Select has `modularPlan` set: the GroupBy's
 *     output and predicates are remapped onto the MV output, but the original child Select
 *     is kept.
 *  5. Anything else is returned unchanged.
 *
 * The stored MV plan is always cast with `asInstanceOf[MVPlanWrapper]` and then to `Select`.
 * Any other shape fails with a ClassCastException.
 *
 * @param modularPlan the plan node to rewrite
 * @return the rewritten node, or the input node itself when no case applies
 */
private def updatePlan(modularPlan: ModularPlan): ModularPlan = {
modularPlan match {
// Case 1: the Select itself was matched to an MV.
case select: Select if select.modularPlan.isDefined =>
val planWrapper = select.modularPlan.get.asInstanceOf[MVPlanWrapper]
val plan = planWrapper.modularPlan.asInstanceOf[Select]
val aliasMap = getAliasMap(plan.outputList, select.outputList)
// Update the flagSpec as per the mv table attributes.
val updatedFlagSpec = updateFlagSpec(select, plan, aliasMap, keepAlias = false)
// When the output list projects the same column more than once but the relation holds
// distinct columns, a positional mapping would pick the wrong columns. The query outputs
// are therefore de-duplicated (`.distinct`) before being paired with the MV outputs.
val updatedPlanOutputList = getUpdatedOutputList(plan.outputList, select.modularPlan)
// Pair query outputs with MV outputs by position; `zip` stops at the shorter list.
// When the names differ, alias the MV column back to the query's name and exprId.
val outputList =
for ((output1, output2) <- select.outputList.distinct zip updatedPlanOutputList) yield {
if (output1.name != output2.name) {
Alias(output2, output1.name)(exprId = output1.exprId)
} else {
output2
}
}
// Use the MV Select, keeping the query's flags.
plan.copy(outputList = outputList, flags = select.flags, flagSpec = updatedFlagSpec)
.setRewritten()
// Case 2: a Select whose only child is a GroupBy that was matched to an MV.
case select: Select => select.children match {
case Seq(groupBy: GroupBy) if groupBy.modularPlan.isDefined =>
val planWrapper = groupBy.modularPlan.get.asInstanceOf[MVPlanWrapper]
val plan = planWrapper.modularPlan.asInstanceOf[Select]
if (!planWrapper.viewSchema.isRefreshIncremental) {
// Non-incremental MV: the GroupBy is replaced by the MV Select itself.
val aliasMap = getAliasMap(plan.outputList, groupBy.outputList)
// Update the flagSpec as per the mv table attributes.
val updatedFlagSpec = updateFlagSpec(select, plan, aliasMap, keepAlias = false)
val updatedPlanOutputList = getUpdatedOutputList(plan.outputList, groupBy.modularPlan)
// Positional pairing of GroupBy outputs with MV outputs (zip stops at the shorter list).
// Differing names are aliased back to the GroupBy output's name and exprId.
val outputList =
for ((output1, output2) <- groupBy.outputList zip updatedPlanOutputList) yield {
if (output1.name != output2.name) {
Alias(output2, output1.name)(exprId = output1.exprId)
} else {
output2
}
}
// Directly keep the relation as child.
// Each Select output is looked up by name in the remapped list. `.get` throws
// NoSuchElementException if a Select output name has no counterpart.
select.copy(
outputList = select.outputList.map {
output => outputList.find(_.name.equals(output.name)).get
},
children = Seq(plan),
aliasMap = plan.aliasMap,
flagSpec = updatedFlagSpec).setRewritten()
} else {
// Incremental MV: rewrite the GroupBy itself (handled by case 3 through the recursive
// call) and keep the rewritten GroupBy as the Select's child.
val child = updatePlan(groupBy).asInstanceOf[Matchable]
// First find the indices from the child output list.
// For each Select output, find the position of the GroupBy output it refers to. Alias
// children are compared with semanticEquals. indexWhere yields -1 when nothing matches.
// NOTE(review): an index of -1 would make `child.outputList(_)` below throw.
val outputIndices = select.outputList.map {
output =>
groupBy.outputList.indexWhere {
case alias: Alias if output.isInstanceOf[Alias] =>
alias.child.semanticEquals(output.asInstanceOf[Alias].child)
case alias: Alias if alias.child.semanticEquals(output) =>
true
case other if output.isInstanceOf[Alias] =>
other.semanticEquals(output.asInstanceOf[Alias].child)
case other =>
other.semanticEquals(output) || other.toAttribute.semanticEquals(output)
}
}
// Get the outList from converted child output list using already selected indices
// (assumes the rewritten child keeps the GroupBy's output order — TODO confirm).
// When the Select output is an Alias, the child expression is re-aliased to the Select's
// alias name and exprId. Otherwise the child output is used as is.
val outputList =
outputIndices.map(child.outputList(_)).zip(select.outputList).map {
case (output1, output2) =>
output1 match {
case alias: Alias if output2.isInstanceOf[Alias] =>
Alias(alias.child, output2.name)(exprId = output2.exprId)
case alias: Alias =>
alias
case other if output2.isInstanceOf[Alias] =>
Alias(other, output2.name)(exprId = output2.exprId)
case other =>
other
}
}
val aliasMap = getAliasMap(child.outputList, groupBy.outputList)
// Update the flagSpec as per the mv table attributes.
val updatedFlagSpec = updateFlagSpec(select, plan, aliasMap, keepAlias = false)
// TODO Remove the unnecessary columns from selection.
// Only keep columns which are required by parent.
select.copy(
outputList = outputList,
inputList = child.outputList,
flagSpec = updatedFlagSpec,
children = Seq(child)).setRewritten()
}
// Select without an MV-matched GroupBy child: unchanged.
case _ => select
}
// Case 3: the GroupBy itself was matched to an MV; its child becomes the MV Select.
case groupBy: GroupBy if groupBy.modularPlan.isDefined =>
val planWrapper = groupBy.modularPlan.get.asInstanceOf[MVPlanWrapper]
val plan = planWrapper.modularPlan.asInstanceOf[Select]
val updatedPlanOutputList = getUpdatedOutputList(plan.outputList, groupBy.modularPlan)
// columnIndex is used to iterate over updatedPlanOutputList. For each avg attribute,
// updatedPlanOutputList has 2 attributes (sum and count) and
// by maintaining index we can increment and access when needed.
var columnIndex = -1
// Returns the string used to search MV output column names for an aggregate's argument:
// - the attribute name;
// - the literal's value as a string;
// - otherwise "". Every name contains "", so only the aggregate name is then matched.
// NOTE(review): a Literal whose value is null would throw NullPointerException on toString.
def getColumnName(expression: Expression): String = {
expression match {
case attributeReference: AttributeReference => attributeReference.name
case literal: Literal => literal.value.toString
case _ => ""
}
}
// get column from list having the given aggregate and column name.
// Prefers the column right after columnIndex, and advances columnIndex when it matches.
// Otherwise returns the first column anywhere in the list whose name contains both
// strings, without moving columnIndex. Matching is by substring.
// `.get` throws NoSuchElementException when no column matches.
def getColumnFromOutputList(updatedPlanOutputList: Seq[NamedExpression], aggregate: String,
colName: String): NamedExpression = {
val nextIndex = columnIndex + 1
if ((nextIndex) < updatedPlanOutputList.size &&
updatedPlanOutputList(nextIndex).name.contains(aggregate) &&
updatedPlanOutputList(nextIndex).name.contains(colName)) {
columnIndex += 1
updatedPlanOutputList(columnIndex)
} else {
updatedPlanOutputList.find(x => x.name.contains(aggregate) &&
x.name.contains(colName)).get
}
}
// avg is detected textually: some output's SQL contains the CarbonCommonConstants.AVERAGE
// function name immediately followed by "(".
// Each mapping entry is (query output, MV column, optional MV count column for avg).
val outputListMapping = if (groupBy.outputList
.exists(_.sql.contains(CarbonCommonConstants.AVERAGE + "("))) {
// for each avg attribute, updatedPlanOutputList has 2 attributes (sum and count),
// so direct mapping of groupBy.outputList and updatedPlanOutputList is not possible.
// If query has avg, then get the sum, count attributes in the list and map accordingly.
for (exp <- groupBy.outputList) yield {
exp match {
case Alias(aggregateExpression: AggregateExpression, _)
if aggregateExpression.aggregateFunction.isInstanceOf[Average] =>
// avg(col) maps to the MV's sum column plus its count column for the same argument.
val colName = getColumnName(aggregateExpression.collectLeaves().head)
val sumAttr = getColumnFromOutputList(updatedPlanOutputList,
CarbonCommonConstants.SUM, colName)
val countAttr = getColumnFromOutputList(updatedPlanOutputList,
CarbonCommonConstants.COUNT, colName)
(exp, sumAttr, Some(countAttr))
case Alias(aggregateExpression: AggregateExpression, _)
if aggregateExpression.aggregateFunction.isInstanceOf[Sum] ||
aggregateExpression.aggregateFunction.isInstanceOf[Count] =>
// If query contains avg aggregate and also sum or count of column,
// duplicate column creation is avoided. The column might have already mapped
// with avg, so search from output list to find the column and map.
val colName = getColumnName(aggregateExpression.collectLeaves().head)
val colAttr = getColumnFromOutputList(updatedPlanOutputList,
aggregateExpression.aggregateFunction.prettyName, colName)
(exp, colAttr, None)
case _ =>
// Any other output consumes the next MV column in order.
columnIndex += 1
(exp, updatedPlanOutputList(columnIndex), None)
}
}
} else {
// No avg: query outputs map one-to-one to MV columns by position, with no count column.
// `zipped` stops at the shortest of the three lists.
(groupBy.outputList, updatedPlanOutputList, List.fill(updatedPlanOutputList.size)(None))
.zipped.toList
}
// Derive the rewritten output and predicate lists from the mapping.
val (outputList: Seq[NamedExpression], updatedPredicates: Seq[Expression]) =
getUpdatedOutputAndPredicateList(
groupBy,
outputListMapping)
// Read from the MV Select and clear modularPlan on the rewritten GroupBy.
groupBy.copy(
outputList = outputList,
inputList = plan.outputList,
predicateList = updatedPredicates,
child = plan,
modularPlan = None).setRewritten()
// Case 4: a GroupBy with predicates over a Select that was matched to an MV.
case groupBy: GroupBy if groupBy.predicateList.nonEmpty => groupBy.child match {
case select: Select if select.modularPlan.isDefined =>
val planWrapper = select.modularPlan.get.asInstanceOf[MVPlanWrapper]
val plan = planWrapper.modularPlan.asInstanceOf[Select]
val updatedPlanOutputList = getUpdatedOutputList(plan.outputList, select.modularPlan)
// Positional one-to-one mapping onto the MV columns (zipped stops at the shortest list).
val outputListMapping = (groupBy.outputList, updatedPlanOutputList, List.fill(
updatedPlanOutputList.size)(None)).zipped.toList
val (outputList: Seq[NamedExpression], updatedPredicates: Seq[Expression]) =
getUpdatedOutputAndPredicateList(
groupBy,
outputListMapping)
// NOTE(review): unlike the other cases, this keeps the original child Select (whose
// modularPlan is still set) and does not call setRewritten(). Confirm this is intended.
groupBy.copy(
outputList = outputList,
inputList = plan.outputList,
predicateList = updatedPredicates,
child = select,
modularPlan = None)
case _ => groupBy
}
// Case 5: not an MV-matched node; returned unchanged.
case other => other
}
}