def inferFeatureListTypes()

in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala [91:111]


  /**
   * Infers the Spark SQL type of each feature list of one record and merges it into the schema
   * accumulated so far.
   *
   * The types inferred (via `inferField`) for the individual features of a `FeatureList` are
   * reduced pairwise with `findTightestCommonType`. The reduced type is then wrapped so the column
   * type is a two-level array: an `ArrayType` result is wrapped once, and any other type twice.
   * If the feature name is already in `schemaSoFar`, the new type is merged with the existing one
   * via `findTightestCommonType`. The entry becomes `null` when that merge returns `None`.
   *
   * A feature list with no entries carries no type information. It is skipped, so it does not
   * fail inference (as `reduceLeft` on an empty collection would) and does not alter any existing
   * entry for that name.
   *
   * @param schemaSoFar    accumulated feature name -> type map; mutated in place and returned
   * @param featureListMap feature lists of the current record, keyed by feature name
   * @return `schemaSoFar`, after merging in the types inferred from `featureListMap`
   */
  def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType],
                            featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = {
    featureListMap.foreach { case (featureName, featureList) =>
      val elementTypes = featureList.getFeatureList.asScala.map(inferField(_))
      // reduceLeftOption: an empty FeatureList yields None (skipped) instead of throwing
      // UnsupportedOperationException("empty.reduceLeft").
      elementTypes.reduceLeftOption((a, b) => findTightestCommonType(a, b).orNull).foreach { featureType =>
        // NOTE(review): if the features in this list have no common type, featureType is null
        // here and gets wrapped as ArrayType(ArrayType(null)). Confirm how downstream code treats
        // that compared with a plain null entry.
        val currentType = featureType match {
          case ArrayType(_, _) => ArrayType(featureType)
          case _ => ArrayType(ArrayType(featureType))
        }
        schemaSoFar(featureName) = schemaSoFar.get(featureName) match {
          case Some(existingType) => findTightestCommonType(existingType, currentType).orNull
          case None => currentType
        }
      }
    }
    schemaSoFar
  }