in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/protobuf/ProtobufConverter.scala [54:80]
/**
  * Converts a Spark [[Row]] into an Amazon Record protobuf message.
  *
  * The features column must contain a [[Vector]] or a [[Matrix]]; the optional
  * label column, when present in the row's schema, is read as a Double.
  *
  * @param row row to convert; must carry a non-null schema
  * @param featuresFieldName name of the column holding the features
  * @param labelFieldName optional name of the column holding the label; if the
  *                       named column is absent from the schema, no label is set
  * @return the built protobuf Record
  * @throws IllegalArgumentException if the features column is missing or holds
  *                                  a value that is neither a Vector nor a Matrix
  */
def rowToProtobuf(row: Row, featuresFieldName: String,
                  labelFieldName: Option[String] = Option.empty): Record = {
  require(row.schema != null, "Row schema is null for row " + row)
  val protobufBuilder : Builder = Record.newBuilder()

  // Set the label only when a label column name was supplied AND that column
  // actually exists in this row's schema; otherwise silently skip (unchanged
  // behavior: unlabeled data is valid, e.g. for batch transform).
  labelFieldName.filter(row.schema.fieldNames.contains).foreach { labelColumn =>
    setLabel(protobufBuilder, row.getAs[Double](labelColumn))
  }

  // Guard clause: fail fast with a descriptive message when the features
  // column is absent (replaces the redundant `else if (!hasFeaturesColumn)`).
  if (!row.schema.fieldNames.contains(featuresFieldName)) {
    throw new IllegalArgumentException(s"Need a features column with a " +
      s"Vector of doubles named $featuresFieldName to convert row to protobuf")
  }

  row.get(row.fieldIndex(featuresFieldName)) match {
    case v : Vector =>
      setFeatures(protobufBuilder, v)
    case m : Matrix =>
      setFeatures(protobufBuilder, m)
    // Previously a non-exhaustive match: an unsupported cell type surfaced as
    // an opaque scala.MatchError. Raise a descriptive error instead.
    case other =>
      throw new IllegalArgumentException(s"Features column $featuresFieldName " +
        s"must contain a Vector or Matrix, but found " +
        s"${if (other == null) "null" else other.getClass.getName}")
  }
  protobufBuilder.build
}