in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala [1076:1292]
/**
 * Generates Java code that evaluates `inputData`, binds each of its elements to `loopVar`,
 * evaluates `lambdaFunction` for that element, and collects the per-element results.
 *
 * The result container is chosen from `customCollectionCls`:
 *  - a Scala mutable or immutable `ArraySeq`,
 *  - any other Scala `Seq`/`Set`,
 *  - a `java.util.List`, or
 *  - a `java.util.Set`.
 * In every other case (including `None`), results are stored in a Java array that is wrapped
 * in a `GenericArrayData`.
 *
 * If `inputData` evaluates to null, the loop is skipped and `ev.value` keeps the default value
 * for `dataType`. The returned `isNull` is the input's null flag.
 *
 * @param ctx codegen context used to allocate fresh Java identifiers and reference objects
 * @param ev  the output slot; the returned copy replaces only its `code` and `isNull`
 * @return the generated code, with `isNull` taken from the evaluated `inputData`
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Java type of one input element: used to declare/cast the loop variable and, for
// `Object` inputs, the runtime array type.
val elementJavaType = CodeGenerator.javaType(loopVar.dataType)
val loopVarCode = LambdaVariable.prepareLambdaVariable(ctx, loopVar)
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
// Fresh names for the generated Java locals.
val dataLength = ctx.freshName("dataLength")
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
// Boxed Java type of one lambda result; the default (array) path below stores results in a
// `$convertedType[]`.
val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
// array declarations (i.e. new String[1][]).
val arrayConstructor = if (convertedType contains "[]") {
val rawType = convertedType.takeWhile(_ != '[')
val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse
s"new $rawType[$dataLength]$arrayPart"
} else {
s"new $convertedType[$dataLength]"
}
// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
// of input collection at runtime for this case.
// The generated code leaves exactly one of `$seq` / `$array` non-null. The
// `ObjectType(Object)` case of the loop-preparation match below reads both locals.
// NOTE(review): this match keys on `inputData.dataType`, while the loop-preparation match keys
// on `inputDataType`. Verify the two agree for `ObjectType(Object)`; otherwise the generated code
// would reference undeclared locals.
val seq = ctx.freshName("seq")
val array = ctx.freshName("array")
val determineCollectionType = inputData.dataType match {
case ObjectType(cls) if cls == classOf[Object] =>
val seqClass = classOf[scala.collection.Seq[_]].getName
s"""
$seqClass $seq = null;
$elementJavaType[] $array = null;
if (${genInputData.value}.getClass().isArray()) {
$array = ($elementJavaType[]) ${genInputData.value};
} else {
$seq = ($seqClass) ${genInputData.value};
}
"""
case _ => ""
}
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
// Each case yields three pieces of generated code:
//  1. an expression for the element count,
//  2. statements run once before the loop,
//  3. an expression for the current element.
// `java.util.Set` is also traversed with an iterator. Arrays and ArrayData are indexed by
// `$loopIndex`. There is no catch-all case, so any other `inputDataType` fails this match
// with a `MatchError` during code generation.
val (getLength, prepareLoop, getLoopVar) = inputDataType match {
case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
val it = ctx.freshName("it")
(
s"${genInputData.value}.size()",
s"scala.collection.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
case ObjectType(cls) if cls.isArray =>
(
s"${genInputData.value}.length",
"",
s"${genInputData.value}[$loopIndex]"
)
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
val it = ctx.freshName("it")
(
s"${genInputData.value}.size()",
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
val it = ctx.freshName("it")
(
s"${genInputData.value}.size()",
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
case ArrayType(et, _) =>
(
s"${genInputData.value}.numElements()",
"",
CodeGenerator.getValue(genInputData.value, et, loopIndex)
)
// Runtime-dispatched input: relies on `$seq` / `$array` from `determineCollectionType`.
// The iterator is null exactly when the input turned out to be an array.
case ObjectType(cls) if cls == classOf[Object] =>
val it = ctx.freshName("it")
(
s"$seq == null ? $array.length : $seq.size()",
s"scala.collection.Iterator $it = $seq == null ? null : $seq.iterator();",
s"$it == null ? $array[$loopIndex] : $it.next()"
)
}
// Make a copy of the data if it's unsafe-backed
// Applied only to struct, array and map lambda results, via a runtime `instanceof` test on
// the matching Unsafe* class. Presumably this is needed because the lambda may reuse the same
// unsafe buffer across iterations. TODO confirm.
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
val genFunctionValue: String = lambdaFunction.dataType match {
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
case _ => genFunction.value
}
// Per-element statement that sets the loop variable's null flag; empty when `loopVar` is not
// nullable. ArrayData input reports nulls via `isNullAt`; other inputs use a plain null check
// on the fetched element.
val loopNullCheck = if (loopVar.nullable) {
inputDataType match {
case _: ArrayType => s"${loopVarCode.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
case _ => s"${loopVarCode.isNull} = ${loopVarCode.value} == null;"
}
} else {
""
}
// For the chosen result container, this yields three pieces of generated code:
//  - initCollection: statements that declare and size the container,
//  - addElement: a function producing the statement that appends one generated value,
//  - getResult: the expression assigned to `ev.value` after the loop, including its
//    trailing `;`.
val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
// Scala mutable ArraySeq: accumulate in a `mutable.ArrayBuilder` built from
// `elementClassTag()`, then wrap the result with `mutable.ArraySeq.make`.
case Some(cls) if classOf[mutable.ArraySeq[_]].isAssignableFrom(cls) =>
val tag = ctx.addReferenceObj("tag", elementClassTag())
val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
val builder = ctx.freshName("collectionBuilder")
(
s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
(genValue: String) => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) ${classOf[mutable.ArraySeq[_]].getName}$$." +
s"MODULE$$.make($builder.result());"
)
// Scala immutable ArraySeq: same builder as above, wrapped without copying via
// `immutable.ArraySeq.unsafeWrapArray`.
case Some(cls) if classOf[immutable.ArraySeq[_]].isAssignableFrom(cls) =>
val tag = ctx.addReferenceObj("tag", elementClassTag())
val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
val builder = ctx.freshName("collectionBuilder")
(
s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
(genValue: String) => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) ${classOf[immutable.ArraySeq[_]].getName}$$." +
s"MODULE$$.unsafeWrapArray($builder.result());"
)
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala sequence or set
// Uses `newBuilder()` on `cls`'s companion object, reached through its `MODULE$`.
// NOTE(review): this assumes `cls` has such a companion; confirm for any custom
// collection classes that callers pass in.
val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
val builder = ctx.freshName("collectionBuilder")
(
s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
(genValue: String) => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) $builder.result();"
)
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
// Java list
// `List`, `AbstractList` and `AbstractSequentialList` cannot be instantiated directly,
// so an `ArrayList` is created for them. Any other class is instantiated itself and is
// given the length only when it declares a public `(int)` constructor.
val builder = ctx.freshName("collectionBuilder")
(
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
cls == classOf[java.util.AbstractSequentialList[_]]) {
s"${cls.getName} $builder = new java.util.ArrayList($dataLength);"
} else {
val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
s"${cls.getName} $builder = new ${cls.getName}($param);"
},
(genValue: String) => s"$builder.add($genValue);",
s"$builder;"
)
case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
// Java set
// `Set` and `AbstractSet` become a `HashSet`. Any other class is instantiated itself,
// with the length passed only when it declares a public `(int)` constructor.
val builder = ctx.freshName("collectionBuilder")
(
if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) {
s"${cls.getName} $builder = new java.util.HashSet($dataLength);"
} else {
val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
s"${cls.getName} $builder = new ${cls.getName}($param);"
},
(genValue: String) => s"$builder.add($genValue);",
s"$builder;"
)
case _ =>
// array
// Pre-sized Java array of boxed results, filled by index and wrapped in GenericArrayData.
(
s"""
$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;
""",
(genValue: String) => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);"
)
}
// Assemble the final code. It evaluates the input and, only when the input is non-null:
//  1. computes the element count and initialises the container,
//  2. for each element, sets the loop variable (and its null flag) and runs the lambda,
//  3. appends `null` when the lambda result is null, otherwise the (possibly copied) value,
//  4. assigns the finished container to `ev.value`.
val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
$determineCollectionType
int $dataLength = $getLength;
$initCollection
int $loopIndex = 0;
$prepareLoop
while ($loopIndex < $dataLength) {
${loopVarCode.value} = ($elementJavaType) ($getLoopVar);
$loopNullCheck
${genFunction.code}
if (${genFunction.isNull}) {
${addElement("null")}
} else {
${addElement(genFunctionValue)}
}
$loopIndex += 1;
}
${ev.value} = $getResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
}