def write()

in spark-connector/hive/src/main/scala/org/apache/spark/sql/hive/execution/OdpsTableWriter.scala [54:215]


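  /**
   * Executes `plan` and writes its output to the ODPS table behind `writeSession`. The input is
   * shuffled and/or sorted as needed for bucketing and dynamic partitions, one write task is run
   * per RDD partition, and the batch write session is committed on success or cleaned up on failure.
   */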
  def write(
      sparkSession: SparkSession,
      plan: SparkPlan,
      writeSession: TableBatchWriteSession,
      description: WriteJobDescription,
      outputColumns: Seq[Attribute],
      bucketSpec: Option[BucketSpec],
      bucketAttributes: Seq[Attribute],
      bucketSortOrders: Seq[SortOrder],
      overwrite: Boolean)
  : Unit = {
    val dynamicPartitionColumns = description.dynamicPartitionColumns
    // For the non-bucketed path the data only needs to be clustered by the dynamic partition
    // columns; bucketed writes are shuffled (and sorted when needed) explicitly below, so buckets
    // and sort columns are not part of the required ordering here.
    val requiredOrdering = dynamicPartitionColumns
    // Only the ordering expressions matter, not the sort direction, so compare the children.
    val actualOrdering = plan.outputOrdering.map(_.child)
    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
      false
    } else {
      requiredOrdering.zip(actualOrdering).forall {
        case (requiredOrder, childOutputOrder) =>
          requiredOrder.semanticEquals(childOutputOrder)
      }
    }

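    // Sanity check that this write runs inside a tracked SQL execution (used by the SQL UI and metrics).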
    SQLExecution.checkSQLExecutionId(sparkSession)

    val identifier = writeSession.getTableIdentifier

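    // Arrange the input so each write task receives rows that belong together: bucketed tables
    // are hash-shuffled on the bucket columns, while non-bucketed tables reuse the plan's output
    // when it is already clustered by the dynamic partition columns and get a local sort otherwise.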
    val tempRdd = bucketSpec match {
      case Some(BucketSpec(numBuckets, _, _)) =>
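        // Shuffle rows by the bucket columns with ODPS hash partitioning so that all rows of a
        // bucket end up in the same write task.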
        val shuffledRdd = new OdpsShuffleExchangeExec(
          OdpsHashPartitioning(bucketAttributes, numBuckets), plan)

        if (bucketSortOrders.nonEmpty || dynamicPartitionColumns.nonEmpty) {
          // Within each bucket, sort by the dynamic partition columns first and then by the
          // bucket sort columns, binding the sort expressions against the output columns.
          val sortColumns = if (dynamicPartitionColumns.nonEmpty) {
            dynamicPartitionColumns.map(SortOrder(_, Ascending)) ++ bucketSortOrders
          } else {
            bucketSortOrders
          }
          val orderingExpr = sortColumns.map(BindReferences.bindReference(_, outputColumns))
          SortExec(
            orderingExpr,
            global = false,
            child = shuffledRdd
          ).execute()
        } else {
          shuffledRdd.execute()
        }

      case _ =>
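        // Non-bucketed table: skip the extra sort if the plan's output ordering already starts
        // with the dynamic partition columns.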
        if (orderingMatched) {
          plan.execute()
        } else {
          val orderingExpr = requiredOrdering
            .map(SortOrder(_, Ascending))
            .map(BindReferences.bindReference(_, outputColumns))
          SortExec(
            orderingExpr,
            global = false,
            child = plan).execute()
        }
    }

    val rdd: RDD[InternalRow] = {
      // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
      // partition rdd to make sure we at least set up one write task to write the metadata.
      if (tempRdd.partitions.length == 0) {
        sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
      } else {
        tempRdd
      }
    }

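    // Per-task writer factory plus driver-side bookkeeping: one commit message slot per
    // partition and a row-count accumulator.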
    val writerFactory = new OdpsWriterFactory(description)
    val useCommitCoordinator = true
    val messages = new Array[WriterCommitMessage](rdd.partitions.length)
    val totalNumRowsAccumulator = new LongAccumulator()
    val customMetrics: Map[String, SQLMetric] = Map.empty

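    // Launch one write task per RDD partition; each task returns a commit message that is
    // collected on the driver as results stream back.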
    try {
      sparkSession.sparkContext.runJob(
        rdd,
        (context: TaskContext, iter: Iterator[InternalRow]) =>
          DataWritingTask.run(writerFactory, context, iter, useCommitCoordinator,
            customMetrics),
        rdd.partitions.indices,
        (index, result: DataWritingTaskResult) => {
          val commitMessage = result.writerCommitMessage
          messages(index) = commitMessage
          totalNumRowsAccumulator.add(result.numRows)
        }
      )

      logInfo(s"Data source write $identifier is committing.")

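      // Gather the per-task commit messages and commit the whole batch write session at once.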
      val results = messages.map(_.asInstanceOf[WriteTaskResult])
      val commitMessageList = results.map(_.commitMessage)
        .reduceOption(_ ++ _)
        .getOrElse(Seq.empty)
        .asJava
      try {
        val (_, duration) = Utils.timeTakenMs {
          writeSession.commit(commitMessageList.toArray(Array.empty[OdpsWriterCommitMessage]))
        }
        processStats(description.statsTrackers, results.map(_.stats), duration)
        logInfo(s"Data source write $identifier committed. Elapsed time: $duration ms.")
      } catch {
        case cause: Throwable =>
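          // The commit failed. If no task actually produced data (all commit messages are null),
          // reproduce the side effects an empty write should still have: create the static
          // partition and/or truncate the target when overwriting. Otherwise rethrow the failure.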
          if (commitMessageList.stream.filter(_ != null).count == 0) {
            val partition = description.staticPartition.toString
            if (!description.staticPartition.isEmpty &&
              dynamicPartitionColumns.isEmpty) {
              val sb = new StringBuilder
              sb.append("ALTER TABLE ")
                .append(identifier.getProject)
                .append(".")
                .append(identifier.getTable)
                .append(" ADD IF NOT EXISTS PARTITION (")
                .append(partition)
                .append(");")
              val instance = SQLTask.run(OdpsClient.builder.getOrCreate.odps, sb.toString)
              instance.waitForSuccess()
              logInfo(s"Data source write $identifier committed empty data. " +
                s"Try to create partition $partition")
            }
            if (overwrite && dynamicPartitionColumns.isEmpty) {
              val sb = new StringBuilder
              sb.append("TRUNCATE TABLE ")
                .append(identifier.getProject)
                .append(".")
                .append(identifier.getTable)
              if (!description.staticPartition.isEmpty) {
                sb.append(" PARTITION (")
                sb.append(partition)
                sb.append(")")
              }
              sb.append(";")
              val instance = SQLTask.run(OdpsClient.builder.getOrCreate.odps, sb.toString)
              instance.waitForSuccess()
              logInfo(s"Data source write $identifier committed empty data. Truncate table.")
            }
          } else {
            throw cause
          }
      }
    } catch {
      case cause: Throwable =>
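        // Running or committing the job failed: clean up the write session, then rethrow,
        // wrapping non-fatal errors as a writing-job-aborted error.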
        logError(s"Data source write $identifier is aborting.")
        try {
          writeSession.cleanup()
        } catch {
          case t: Throwable =>
            logError(s"Data source write $identifier failed to abort.")
            cause.addSuppressed(t)
            throw QueryExecutionErrors.writingJobFailedError(cause)
        }
        logError(s"Data source write $identifier aborted.")
        cause match {
          // Only wrap non fatal exceptions.
          case NonFatal(e) => throw QueryExecutionErrors.writingJobAbortedError(e)
          case _ => throw cause
        }
    }
  }
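
For reference, the orderingMatched test above is a plain prefix check: the extra sort is skipped only when the required ordering expressions form a leading prefix of the plan's existing output ordering. A minimal standalone sketch of the same check, using strings in place of Catalyst expressions and == in place of semanticEquals purely for illustration (the column names are made up):

  // Simplified stand-in for the prefix check in OdpsTableWriter.write.
  def orderingMatched(required: Seq[String], actual: Seq[String]): Boolean =
    if (required.length > actual.length) false
    else required.zip(actual).forall { case (r, a) => r == a }

  // A plan already sorted by (ds, hr, user_id) satisfies a requirement of (ds, hr) ...
  assert(orderingMatched(Seq("ds", "hr"), Seq("ds", "hr", "user_id")))
  // ... but not one of (hr, ds), so a local SortExec would be added in that case.
  assert(!orderingMatched(Seq("hr", "ds"), Seq("ds", "hr", "user_id")))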