in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala [34:57]
/**
 * Infers a Spark SQL schema from an RDD of TensorFlow records.
 *
 * The per-feature types are computed with a single `aggregate` pass over the RDD:
 * each record is folded into a feature-name -> type map by the record-specific
 * inference function, and the partial maps are combined with `mergeFieldTypes`.
 *
 * @param rdd records to scan; the element type must be `Example` or `SequenceExample`
 * @tparam T  record type, resolved through its `TypeTag`
 * @return a `StructType` holding one `StructField` per inferred feature
 * @throws IllegalArgumentException if `T` is neither `Example` nor `SequenceExample`
 */
def apply[T : TypeTag](rdd: RDD[T]): StructType = {
  val emptySchema: mutable.Map[String, DataType] = mutable.Map.empty[String, DataType]
  val recordType = typeOf[T]

  // Dispatch on the statically known record type; the cast is guarded by the type check.
  val inferredTypes: mutable.Map[String, DataType] =
    if (recordType =:= typeOf[Example]) {
      rdd.asInstanceOf[RDD[Example]].aggregate(emptySchema)(inferExampleRowType, mergeFieldTypes)
    } else if (recordType =:= typeOf[SequenceExample]) {
      rdd.asInstanceOf[RDD[SequenceExample]].aggregate(emptySchema)(inferSequenceExampleRowType, mergeFieldTypes)
    } else {
      throw new IllegalArgumentException("Unsupported recordType: recordType can be Example or SequenceExample")
    }

  // A null type (presumably a feature whose type could not be determined -- confirm
  // against the inference helpers) falls back to StringType.
  val fields = inferredTypes.map { case (name, dataType) =>
    StructField(name, Option(dataType).getOrElse(StringType))
  }
  StructType(fields.toSeq)
}