in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala [91:111]
/**
 * Infers a Spark type for every FeatureList in `featureListMap` and merges it into
 * `schemaSoFar`.
 *
 * For each feature list, the type of every contained Feature is inferred with `inferField`
 * and the results are folded together with `findTightestCommonType` (a `None` result is
 * turned into `null`). If the folded type is already an `ArrayType`, it is wrapped in one
 * more `ArrayType`; otherwise it is wrapped in two. When `schemaSoFar` already holds an
 * entry for the feature name, the new type is merged with it via `findTightestCommonType`
 * (again `null` on `None`); otherwise the new type is inserted.
 *
 * Feature lists that contain no features are skipped. They carry no type information, and
 * reducing an empty collection would otherwise throw.
 *
 * @param schemaSoFar    schema accumulated so far; updated in place
 * @param featureListMap feature lists keyed by feature name
 * @return `schemaSoFar`, after the updates
 */
def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType],
    featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = {
  featureListMap.foreach { case (featureName, featureList) =>
    val elementTypes = featureList.getFeatureList.asScala.map(f => inferField(f))
    // An empty FeatureList has no Feature to infer from. Skip it instead of letting
    // reduceLeft throw UnsupportedOperationException("empty.reduceLeft").
    elementTypes.reduceLeftOption((a, b) => findTightestCommonType(a, b).orNull).foreach {
      featureType =>
        // Array-valued features gain one extra nesting level; scalar ones gain two.
        val currentType = featureType match {
          case ArrayType(_, _) => ArrayType(featureType)
          case _ => ArrayType(ArrayType(featureType))
        }
        schemaSoFar.get(featureName) match {
          case Some(existingType) =>
            schemaSoFar(featureName) = findTightestCommonType(existingType, currentType).orNull
          case None =>
            schemaSoFar += (featureName -> currentType)
        }
    }
  }
  schemaSoFar
}