override def shortName()

in spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala [43:81]
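
Write path of the TFRecord data source: createRelation encodes each DataFrame
row as an Example or SequenceExample, saves the records either distributed
(through Hadoop) or to the driver-local filesystem, and returns a
TensorflowRelation over the written data.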


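  // Short name registered via Spark's DataSourceRegister; lets callers refer to
  // this source as .format("tfrecords") instead of its fully qualified class name.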
  override def shortName(): String = "tfrecords"

  // Writes DataFrame as TensorFlow Records
  override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    data: DataFrame): BaseRelation = {

    val path = parameters("path") // required output location; throws if "path" is absent
    val codec = parameters.getOrElse("codec", "") // optional Hadoop compression codec class name; "" means uncompressed

    val recordType = parameters.getOrElse("recordType", "Example") // "Example" (default) or "SequenceExample"

    // Encode each Row as a serialized Example/SequenceExample proto, keyed
    // as (bytes, null) for Hadoop's OutputFormat
    val features = data.rdd.map { row =>
      recordType match {
        case "Example" =>
          val example = DefaultTfRecordRowEncoder.encodeExample(row)
          (new BytesWritable(example.toByteArray), NullWritable.get())
        case "SequenceExample" =>
          val sequenceExample = DefaultTfRecordRowEncoder.encodeSequenceExample(row)
          (new BytesWritable(sequenceExample.toByteArray), NullWritable.get())
        case _ =>
          throw new IllegalArgumentException(
            s"Unsupported recordType $recordType: recordType must be Example or SequenceExample")
      }
    }
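    // The map above is lazy: rows are encoded (and an unsupported recordType is
    // reported) only when saveDistributed or saveLocal triggers the job below.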

    parameters.getOrElse("writeLocality", "distributed") match {
      case "distributed" =>
        saveDistributed(features, path, sqlContext, mode, codec)
      case "local" =>
        saveLocal(features, path, mode, codec)
      case s: String =>
        throw new IllegalArgumentException(
          s"Unsupported writeLocality '$s': expected 'distributed' or 'local'")
    }
    TensorflowRelation(parameters)(sqlContext.sparkSession) // relation over the records just written
  }
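
A minimal usage sketch of the options handled above (a hedged illustration, not
project documentation: it assumes a SparkSession named spark, a DataFrame df
whose schema the row encoder supports, and a placeholder output path):

// Write rows out as TF Example records; "writeLocality" defaults to
// "distributed" and "codec" defaults to uncompressed output.
df.write
  .format("tfrecords")
  .option("recordType", "Example")
  // .option("codec", "org.apache.hadoop.io.compress.GzipCodec") // optional compression
  .mode("overwrite")
  .save("/tmp/tf-output") // placeholder path

// Read the records back through the same source.
val restored = spark.read
  .format("tfrecords")
  .option("recordType", "Example")
  .load("/tmp/tf-output")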