in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder.scala [153:200]
/** Encodes the row element at `index` — expected to be an array of arrays — into a
  * TensorFlow `FeatureList`, dispatching on the field's declared nested element type.
  *
  * Integer-like nested types (Int/Long) become Int64 feature lists; fractional types
  * (Float/Double/Decimal) become Float feature lists (narrowing to Float); String and
  * Binary become Bytes feature lists.
  *
  * @param row         source row
  * @param structField schema entry describing the element's data type
  * @param index       position of the element within `row`
  * @throws RuntimeException if the declared type is not a supported nested array type
  */
def encodeFeatureList(row: Row, structField: StructField, index: Int): FeatureList = {
  // Shared unwrap: materialize the outer ArrayData, then convert each inner
  // ArrayData to a Seq of the branch-specific element type.
  def nestedSeqs[T](convert: ArrayData => Seq[T]): Seq[Seq[T]] =
    ArrayData.toArrayData(row.get(index)).array.map { inner =>
      convert(ArrayData.toArrayData(inner))
    }.toSeq

  structField.dataType match {
    case ArrayType(ArrayType(IntegerType, _), _) =>
      Int64FeatureListEncoder.encode(nestedSeqs(ad => ad.toIntArray().map(_.toLong).toSeq))
    case ArrayType(ArrayType(LongType, _), _) =>
      Int64FeatureListEncoder.encode(nestedSeqs(ad => ad.toLongArray().toSeq))
    case ArrayType(ArrayType(FloatType, _), _) =>
      FloatFeatureListEncoder.encode(nestedSeqs(ad => ad.toFloatArray().toSeq))
    case ArrayType(ArrayType(DoubleType, _), _) =>
      // Doubles are narrowed to Float — TFRecord float lists are single-precision.
      FloatFeatureListEncoder.encode(nestedSeqs(ad => ad.toDoubleArray().map(_.toFloat).toSeq))
    case ArrayType(ArrayType(DecimalType(), _), _) =>
      FloatFeatureListEncoder.encode(
        nestedSeqs(ad => ad.toArray[Decimal](DataTypes.createDecimalType()).map(_.toFloat).toSeq))
    case ArrayType(ArrayType(StringType, _), _) =>
      // NOTE(review): getBytes with no charset uses the platform default — confirm UTF-8 is intended.
      BytesFeatureListEncoder.encode(
        nestedSeqs(ad => ad.toArray[String](ObjectType(classOf[String])).toSeq.map(_.getBytes)))
    case ArrayType(ArrayType(BinaryType, _), _) =>
      BytesFeatureListEncoder.encode(nestedSeqs(ad => ad.toArray[Array[Byte]](BinaryType).toSeq))
    case _ => throw new RuntimeException(s"Cannot convert row element ${row.get(index)} to FeatureList.")
  }
}