in core/src/main/scala/org/apache/spark/ml/feature/FastVectorAssembler.scala [26:97]
  def this() = this(Identifiable.randomUID("FastVectorAssembler"))

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Schema transformation.
    val schema = dataset.schema
    // Tracks whether a non-categorical (numeric) input has been seen; all categorical
    // columns must appear before the first numeric one.
    var addedNumericField = false

    // Propagate only nominal (categorical) attributes (others only slow down the code).
    val attrs: Array[Attribute] = $(inputCols).flatMap { c =>
      val field = schema(c)
      field.dataType match {
        case _: NumericType | BooleanType =>
          val attr = Attribute.fromStructField(field)
          if (attr.isNominal) {
            if (addedNumericField) {
              throw new SparkException(
                s"Categorical columns must precede all others, column out of order: $c")
            }
            Some(attr.withName(c))
          } else {
            addedNumericField = true
            None
          }
        case _: VectorUDT =>
          val group = AttributeGroup.fromStructField(field)
          if (group.attributes.isDefined) {
            // If attributes are defined, copy them with updated names.
            group.attributes.get.zipWithIndex.map { case (attr, i) =>
              if (attr.isNominal && attr.name.isDefined) {
                if (addedNumericField) {
                  throw new SparkException(
                    s"Categorical columns must precede all others, column out of order: $c")
                }
                attr.withName(c + "_" + attr.name.get)
              } else if (attr.isNominal) {
                if (addedNumericField) {
                  throw new SparkException(
                    s"Categorical columns must precede all others, column out of order: $c")
                }
                attr.withName(c + "_" + i)
              } else {
                addedNumericField = true
                null
              }
            }.filter(attr => attr != null)
          } else {
            addedNumericField = true
            None
          }
        case otherType =>
          throw new SparkException(s"FastVectorAssembler does not support the $otherType type")
      }
    }
    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()

    // Data transformation.
    // Delegate to FastVectorAssembler.assemble (companion object, outside this excerpt) to
    // combine the input values of each row into a single feature vector.
    val assembleFunc = udf { r: Row =>
      FastVectorAssembler.assemble(r.toSeq: _*)
    }
    // Doubles and vectors pass through unchanged; other numeric and boolean columns are cast to Double.
    val args = $(inputCols).map { c =>
      schema(c).dataType match {
        case DoubleType => dataset(c)
        case _: VectorUDT => dataset(c)
        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
      }
    }

    dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
  }
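
  // Illustrative usage (a hedged sketch, not part of this file): the DataFrame `df` and the
  // column names below are assumptions. Categorical (indexed) input columns must be listed
  // before numeric ones, matching the ordering check enforced in transform above.
  //
  //   val assembler = new FastVectorAssembler()
  //     .setInputCols(Array("siteIndex", "deviceIndex", "clicks", "price"))
  //     .setOutputCol("features")
  //   val assembled = assembler.transform(df)  // appends a "features" vector column with ML metadata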