protected def newWriter()

in core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala [308:536]


  protected def newWriter(
      env: SparkEnv,
      worker: PythonWorker,
      inputIterator: Iterator[IN],
      partitionIndex: Int,
      context: TaskContext): Writer

  protected def newReaderIterator(
      stream: DataInputStream,
      writer: Writer,
      startTime: Long,
      env: SparkEnv,
      worker: PythonWorker,
      pid: Option[Int],
      releasedOrClosed: AtomicBoolean,
      context: TaskContext): Iterator[OUT]
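
  // As a hedged sketch, a concrete runner typically implements these two
  // factory hooks by instantiating its own Writer and reader-iterator
  // subclasses (MyWriter and MyReaderIterator below are hypothetical names,
  // not part of this file):
  //
  //   protected override def newWriter(
  //       env: SparkEnv,
  //       worker: PythonWorker,
  //       inputIterator: Iterator[IN],
  //       partitionIndex: Int,
  //       context: TaskContext): Writer =
  //     new MyWriter(env, worker, inputIterator, partitionIndex, context)
  //
  //   protected override def newReaderIterator(
  //       stream: DataInputStream, writer: Writer, startTime: Long, env: SparkEnv,
  //       worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean,
  //       context: TaskContext): Iterator[OUT] =
  //     new MyReaderIterator(
  //       stream, writer, startTime, env, worker, pid, releasedOrClosed, context)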

  /**
   * Responsible for writing the data from the PythonRDD's parent iterator to the
   * Python process.
   */
  abstract class Writer(
      env: SparkEnv,
      worker: PythonWorker,
      inputIterator: Iterator[IN],
      partitionIndex: Int,
      context: TaskContext) {

    @volatile private var _exception: Throwable = _

    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

    /** Contains the throwable thrown while writing the parent iterator to the Python process. */
    def exception: Option[Throwable] = Option(_exception)

    /**
     * Writes a command section to the stream connected to the Python worker.
     */
    protected def writeCommand(dataOut: DataOutputStream): Unit

    /**
     * Writes input data to the stream connected to the Python worker.
     * Returns true if any data was written to the stream, false if the input is exhausted.
     */
    def writeNextInputToStream(dataOut: DataOutputStream): Boolean
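
    // A hedged sketch of what a concrete subclass might do for these two
    // methods; the pickled command bytes and one-element-per-call batching
    // are illustrative assumptions, not this file's contract:
    //
    //   protected override def writeCommand(dataOut: DataOutputStream): Unit = {
    //     val command: Array[Byte] = ...  // e.g. the pickled Python function
    //     dataOut.writeInt(command.length)
    //     dataOut.write(command)
    //   }
    //
    //   override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
    //     if (inputIterator.hasNext) {
    //       PythonRDD.writeIteratorToStream(Iterator(inputIterator.next()), dataOut)
    //       true   // wrote one element; the caller should call this again
    //     } else {
    //       false  // input exhausted; the caller finishes the data section
    //     }
    //   }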

    def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
      val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
      lazy val sockPath = new File(
        authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
          .getOrElse(System.getProperty("java.io.tmpdir")),
        s".${UUID.randomUUID()}.sock")
      try {
        // Partition index
        dataOut.writeInt(partitionIndex)

        PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)

        // Init a ServerSocket to accept method calls from the Python side.
        val isBarrier = context.isInstanceOf[BarrierTaskContext]
        if (isBarrier) {
          if (isUnixDomainSock) {
            serverSocketChannel = Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
            sockPath.deleteOnExit()
            serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
          } else {
            serverSocketChannel = Some(ServerSocketChannel.open())
            serverSocketChannel.foreach(_.bind(
              new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
            // A call to accept() on this ServerSocket should block indefinitely.
            serverSocketChannel.foreach(_.socket().setSoTimeout(0))
          }

          new Thread("accept-connections") {
            setDaemon(true)

            override def run(): Unit = {
              while (serverSocketChannel.get.isOpen()) {
                var sock: SocketChannel = null
                try {
                  sock = serverSocketChannel.get.accept()
                  // Wait for a function call from the Python side.
                  if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
                  authHelper.authClient(sock)
                  val input = new DataInputStream(Channels.newInputStream(sock))
                  val requestMethod = input.readInt()
                  // The BarrierTaskContext function may wait indefinitely, so the socket must
                  // not time out before the function finishes.
                  if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
                  requestMethod match {
                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
                      barrierAndServe(requestMethod, sock)
                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
                      val message = PythonWorkerUtils.readUTF(input)
                      barrierAndServe(requestMethod, sock, message)
                    case _ =>
                      val out = new DataOutputStream(new BufferedOutputStream(
                        Channels.newOutputStream(sock)))
                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
                  }
                } catch {
                  case _: AsynchronousCloseException =>
                    // Ignored to keep the logs quiet: these channels are closed by the task
                    // completion listeners once the task finishes.
                    if (isUnixDomainSock) sockPath.delete()
                } finally {
                  if (sock != null) {
                    sock.close()
                  }
                }
              }
            }
          }.start()
        }
        if (isBarrier) {
          // Close ServerSocket on task completion.
          serverSocketChannel.foreach { server =>
            context.addTaskCompletionListener[Unit] { _ =>
              server.close()
              if (isUnixDomainSock) sockPath.delete()
            }
          }
          if (isUnixDomainSock) {
            logDebug(s"Started ServerSocket on with Unix Domain Socket $sockPath.")
            dataOut.writeBoolean(/* isBarrier = */true)
            dataOut.writeInt(-1)
            PythonRDD.writeUTF(sockPath.getPath, dataOut)
          } else {
            val boundPort: Int = serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
            if (boundPort == -1) {
              val message = "ServerSocket failed to bind on the Java side."
              logError(message)
              throw new SparkException(message)
            }
            logDebug(s"Started ServerSocket on port $boundPort.")
            dataOut.writeBoolean(/* isBarrier = */true)
            dataOut.writeInt(boundPort)
            PythonRDD.writeUTF(authHelper.secret, dataOut)
          }
        } else {
          dataOut.writeBoolean(/* isBarrier = */false)
        }
        // Write out the TaskContextInfo
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
        dataOut.writeLong(context.taskAttemptId())
        dataOut.writeInt(context.cpus())
        val resources = context.resources()
        dataOut.writeInt(resources.size)
        resources.foreach { case (k, v) =>
          PythonRDD.writeUTF(k, dataOut)
          PythonRDD.writeUTF(v.name, dataOut)
          dataOut.writeInt(v.addresses.length)
          v.addresses.foreach { addr =>
            PythonRDD.writeUTF(addr, dataOut)
          }
        }
        val localProps = context.getLocalProperties.asScala
        dataOut.writeInt(localProps.size)
        localProps.foreach { case (k, v) =>
          PythonRDD.writeUTF(k, dataOut)
          PythonRDD.writeUTF(v, dataOut)
        }

        PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
        PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)

        dataOut.writeInt(evalType)
        writeCommand(dataOut)

        dataOut.flush()
      } catch {
        case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
          if (context.isCompleted() || context.isInterrupted()) {
            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
              "cleanup)", t)
            if (worker.channel.isConnected) {
              Utils.tryLog(worker.channel.shutdownOutput())
            }
          } else {
            // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
            // exception handler will kill the whole executor (see
            // org.apache.spark.executor.Executor).
            _exception = t
            if (worker.channel.isConnected) {
              Utils.tryLog(worker.channel.shutdownOutput())
            }
          }
      }
    }
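
    // For reference, open() emits the following frames, which the Python
    // worker reads back in the same fixed order (this is a summary of the
    // method above, not a separate protocol spec):
    //   1. partition index                              (int)
    //   2. Python version                               (UTF string)
    //   3. isBarrier flag (boolean); for barrier tasks, either a bound port
    //      plus the auth secret, or -1 plus a Unix domain socket path
    //   4. stageId, partitionId, attemptNumber, taskAttemptId, cpus
    //   5. resources, then local properties             (count-prefixed maps)
    //   6. Spark files and broadcast variables
    //   7. eval type (int), then the command section from writeCommand()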

    def close(dataOut: DataOutputStream): Unit = {
      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
      dataOut.flush()
    }

    /**
     * Gateway to call BarrierTaskContext methods.
     */
    def barrierAndServe(requestMethod: Int, sock: SocketChannel, message: String = ""): Unit = {
      require(
        serverSocketChannel.isDefined,
        "No available ServerSocket to redirect the BarrierTaskContext method call."
      )
      val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
      try {
        val messages = requestMethod match {
          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
            context.asInstanceOf[BarrierTaskContext].barrier()
            Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
            context.asInstanceOf[BarrierTaskContext].allGather(message)
        }
        out.writeInt(messages.length)
        messages.foreach(writeUTF(_, out))
      } catch {
        case e: SparkException =>
          writeUTF(e.getMessage, out)
      } finally {
        out.close()
      }
    }
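
    // For illustration, the exchange served here looks like this on the wire
    // (the request ints are read by the accept-connections thread in open();
    // the Python side is the client):
    //   request:  int requestMethod             -- BARRIER_FUNCTION or
    //                                              ALL_GATHER_FUNCTION
    //             UTF message                   -- ALL_GATHER_FUNCTION only
    //   response: int n, then n UTF strings     -- results on success
    //             one UTF string                -- SparkException message on failure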

    def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
      PythonWorkerUtils.writeUTF(str, dataOut)
    }
  }
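
Putting the pieces together, a runner drives a Writer roughly as follows. This
is a minimal sketch that assumes a concrete Writer subclass and an
already-connected dataOut stream to the worker; the driver loop itself is
illustrative, not code from this file:

  val writer = newWriter(env, worker, inputIterator, partitionIndex, context)
  writer.open(dataOut)                       // handshake, task context, command
  while (writer.writeNextInputToStream(dataOut)) {
    // keep streaming until the input iterator is exhausted
  }
  writer.close(dataOut)                      // END_OF_STREAM marker
  writer.exception.foreach(e => throw e)     // surface any write-side failure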