in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala [43:81]
  override def shortName(): String = "tfrecords"

  // Writes DataFrame as TensorFlow Records
  override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    data: DataFrame): BaseRelation = {

    val path = parameters("path")
    val codec = parameters.getOrElse("codec", "")
    val recordType = parameters.getOrElse("recordType", "Example")

    // Export DataFrame as TFRecords
    val features = data.rdd.map(row => {
      recordType match {
        case "Example" =>
          val example = DefaultTfRecordRowEncoder.encodeExample(row)
          (new BytesWritable(example.toByteArray), NullWritable.get())
        case "SequenceExample" =>
          val sequenceExample = DefaultTfRecordRowEncoder.encodeSequenceExample(row)
          (new BytesWritable(sequenceExample.toByteArray), NullWritable.get())
        case _ =>
          throw new IllegalArgumentException(
            s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample")
      }
    })

    parameters.getOrElse("writeLocality", "distributed") match {
      case "distributed" =>
        saveDistributed(features, path, sqlContext, mode, codec)
      case "local" =>
        saveLocal(features, path, mode, codec)
      case s: String =>
        throw new IllegalArgumentException(
          s"Expected 'distributed' or 'local', got $s")
    }

    TensorflowRelation(parameters)(sqlContext.sparkSession)
  }