in backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExtendedColumnPruning.scala [36:141]
/**
 * Extractor that attempts extended nested-column pruning on a
 * `Project(Filter(Generate))` plan shape.
 *
 * Returns `Some(rewrittenPlan)` when pruning was applied (possibly only
 * partially), `None` when the pattern does not match, the feature is
 * disabled, or there are no extract-value expressions to prune.
 */
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
  case pj @ Project(projectList, f @ Filter(condition, g: Generate))
    if canPruneGenerator(g.generator) &&
      GlutenConfig.get.enableExtendedColumnPruning &&
      (SQLConf.get.nestedPruningOnExpressions || SQLConf.get.nestedSchemaPruningEnabled) =>
    // Collect all nested-field accessors referenced by the project list,
    // the generator inputs, and the filter condition.
    val attrToExtractValues =
      getAttributeToExtractValues(projectList ++ g.generator.children :+ condition, Seq.empty)
    if (attrToExtractValues.isEmpty) {
      return None
    }
    // Split accessors into those reading the generator's output versus
    // those reading columns that merely pass through the Generate node.
    val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
    var (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
      attrToExtractValues.partition {
        case (attr, _) =>
          attr.references.subsetOf(generatorOutputSet)
      }
    // Prune the pass-through accessors first; this rewrite is always safe.
    val pushedThrough = rewritePlanWithAliases(pj, attrToExtractValuesNotOnGenerator)
    // We cannot push through if the child of generator is `MapType`.
    g.generator.children.head.dataType match {
      case _: MapType => return Some(pushedThrough)
      case _ =>
    }
    // Only explode-style generators are handled below; for anything else,
    // stop after the pass-through pruning.
    if (!g.generator.isInstanceOf[ExplodeBase]) {
      return Some(pushedThrough)
    }
    // In spark3.2, we could not reuse [[NestedColumnAliasing.getAttributeToExtractValues]]
    // which only accepts 2 arguments. Instead we redefine it in current file to avoid moving
    // this rule to gluten-shims
    attrToExtractValuesOnGenerator = getAttributeToExtractValues(
      attrToExtractValuesOnGenerator.flatMap(_._2).toSeq,
      Seq.empty,
      collectNestedGetStructFields)
    val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
    if (nestedFieldsOnGenerator.isEmpty) {
      return Some(pushedThrough)
    }
    // Multiple or single nested column accessors.
    // E.g. df.select(explode($"items").as("item")).select($"item.a", $"item.b")
    pushedThrough match {
      case p2 @ Project(_, f2 @ Filter(_, g2: Generate)) =>
        // Fix an ordinal for each accessed nested field; the generator is
        // rewritten to explode a zipped array of exactly those fields, and
        // downstream accessors are redirected to those ordinals.
        val nestedFieldsOnGeneratorSeq = nestedFieldsOnGenerator.toSeq
        val nestedFieldToOrdinal = nestedFieldsOnGeneratorSeq.zipWithIndex.toMap
        val rewrittenG = g2.transformExpressions {
          case e: ExplodeBase =>
            val extractors = nestedFieldsOnGeneratorSeq.map(replaceGenerator(e, _))
            val names = extractors.map {
              case g: GetStructField => Literal(g.extractFieldName)
              case ga: GetArrayStructFields => Literal(ga.field.name)
              case other =>
                // BUGFIX: the second literal previously lacked the `s`
                // interpolator, so the message contained the text "$other"
                // instead of the offending expression.
                throw new IllegalStateException(
                  s"Unreasonable extractor after replaceGenerator: $other")
            }
            val zippedArray = ArraysZip(extractors, names)
            e.withNewChildren(Seq(zippedArray))
        }
        // As we change the child of the generator, its output data type must be updated.
        val updatedGeneratorOutput = rewrittenG.generatorOutput
          .zip(
            rewrittenG.generator.elementSchema.map(
              f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
          .map {
            case (oldAttr, newAttr) =>
              newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
          }
        // zip truncates to the shorter side, so a schema/output length
        // mismatch would surface here rather than corrupt the plan.
        assert(
          updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
          "Updated generator output must have the same length " +
            "with original generator output."
        )
        val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
        // Replace nested column accessor with generator output.
        val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
        val updatedFilter = f2.withNewChildren(Seq(updatedGenerate)).transformExpressions {
          case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
            replaceGetStructField(
              f,
              updatedGenerate.output,
              attrExprIdsOnGenerator,
              nestedFieldToOrdinal)
        }
        val updatedProject = p2.withNewChildren(Seq(updatedFilter)).transformExpressions {
          case f: GetStructField if nestedFieldsOnGenerator.contains(f) =>
            replaceGetStructField(
              f,
              updatedFilter.output,
              attrExprIdsOnGenerator,
              nestedFieldToOrdinal)
        }
        Some(updatedProject)
      case other =>
        // rewritePlanWithAliases is expected to preserve the
        // Project/Filter/Generate shape; anything else is a logic error.
        throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
    }
  case _ =>
    None
}