in core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala [308:536]
  protected def newWriter(
      env: SparkEnv,
      worker: PythonWorker,
      inputIterator: Iterator[IN],
      partitionIndex: Int,
      context: TaskContext): Writer

  protected def newReaderIterator(
      stream: DataInputStream,
      writer: Writer,
      startTime: Long,
      env: SparkEnv,
      worker: PythonWorker,
      pid: Option[Int],
      releasedOrClosed: AtomicBoolean,
      context: TaskContext): Iterator[OUT]
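
  // A minimal sketch of how a concrete runner might implement this factory
  // pair (illustrative only; the writer body below is hypothetical, not one of
  // the actual subclasses in this file):
  //
  //   protected def newWriter(
  //       env: SparkEnv,
  //       worker: PythonWorker,
  //       inputIterator: Iterator[IN],
  //       partitionIndex: Int,
  //       context: TaskContext): Writer =
  //     new Writer(env, worker, inputIterator, partitionIndex, context) {
  //       protected def writeCommand(dataOut: DataOutputStream): Unit =
  //         writeSerializedCommand(dataOut)  // hypothetical helper
  //       def writeNextInputToStream(dataOut: DataOutputStream): Boolean =
  //         writeNextBatch(dataOut)          // hypothetical helper
  //     }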

  /**
   * Responsible for writing the data from the PythonRDD's parent iterator to the
   * Python process.
   */
  abstract class Writer(
      env: SparkEnv,
      worker: PythonWorker,
      inputIterator: Iterator[IN],
      partitionIndex: Int,
      context: TaskContext) {

    @volatile private var _exception: Throwable = _

    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

    /** Contains the throwable thrown while writing the parent iterator to the Python process. */
    def exception: Option[Throwable] = Option(_exception)

    /**
     * Writes a command section to the stream connected to the Python worker.
     */
    protected def writeCommand(dataOut: DataOutputStream): Unit

    /**
     * Writes input data to the stream connected to the Python worker.
     * Returns true if any data was written to the stream, false if the input is exhausted.
     */
    def writeNextInputToStream(dataOut: DataOutputStream): Boolean
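
    // A minimal sketch of one possible implementation (illustrative only):
    // write one encoded element per call and report exhaustion.
    //
    //   def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
    //     if (inputIterator.hasNext) {
    //       writeElement(inputIterator.next(), dataOut)  // hypothetical helper;
    //       true                                         // real writers use the
    //     } else {                                       // serializer matching
    //       false                                        // their eval type
    //     }
    //   }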

    def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
      val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
      lazy val sockPath = new File(
        authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
          .getOrElse(System.getProperty("java.io.tmpdir")),
        s".${UUID.randomUUID()}.sock")
      try {
        // Partition index
        dataOut.writeInt(partitionIndex)

        PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)

        // Init a ServerSocket to accept method calls from the Python side.
        val isBarrier = context.isInstanceOf[BarrierTaskContext]
        if (isBarrier) {
          if (isUnixDomainSock) {
            serverSocketChannel = Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
            sockPath.deleteOnExit()
            serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
          } else {
            serverSocketChannel = Some(ServerSocketChannel.open())
            serverSocketChannel.foreach(_.bind(
              new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
            // A call to accept() on this ServerSocket will block indefinitely.
            serverSocketChannel.foreach(_.socket().setSoTimeout(0))
          }

          new Thread("accept-connections") {
            setDaemon(true)

            override def run(): Unit = {
              while (serverSocketChannel.get.isOpen()) {
                var sock: SocketChannel = null
                try {
                  sock = serverSocketChannel.get.accept()
                  // Wait for the function call from the Python side.
                  if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
                  authHelper.authClient(sock)
                  val input = new DataInputStream(Channels.newInputStream(sock))
                  val requestMethod = input.readInt()
                  // The BarrierTaskContext function may wait indefinitely, so the socket
                  // must not time out before the function finishes.
                  if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
                  requestMethod match {
                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
                      barrierAndServe(requestMethod, sock)
                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
                      val message = PythonWorkerUtils.readUTF(input)
                      barrierAndServe(requestMethod, sock, message)
                    case _ =>
                      val out = new DataOutputStream(new BufferedOutputStream(
                        Channels.newOutputStream(sock)))
                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
                  }
                } catch {
                  case _: AsynchronousCloseException =>
                    // Ignored to reduce noise: the channels are closed by the task
                    // completion listener when the task finishes.
                    if (isUnixDomainSock) sockPath.delete()
                } finally {
                  if (sock != null) {
                    sock.close()
                  }
                }
              }
            }
          }.start()
        }
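
        // Wire protocol served by the accept loop above, as implemented in
        // barrierAndServe() below: the Python side sends an int request method
        // (plus one UTF string message for ALL_GATHER_FUNCTION), and receives
        // an int message count followed by that many UTF strings, or a single
        // UTF error message if the call fails with a SparkException.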
        if (isBarrier) {
          // Close ServerSocket on task completion.
          serverSocketChannel.foreach { server =>
            context.addTaskCompletionListener[Unit] { _ =>
              server.close()
              if (isUnixDomainSock) sockPath.delete()
            }
          }
          if (isUnixDomainSock) {
            logDebug(s"Started ServerSocket with Unix Domain Socket $sockPath.")
            dataOut.writeBoolean(/* isBarrier = */ true)
            dataOut.writeInt(-1)
            PythonRDD.writeUTF(sockPath.getPath, dataOut)
          } else {
            val boundPort: Int = serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
            if (boundPort == -1) {
              val message = "ServerSocket failed to bind on the Java side."
              logError(message)
              throw new SparkException(message)
            }
logDebug(s"Started ServerSocket on port $boundPort.")
dataOut.writeBoolean(/* isBarrier = */true)
dataOut.writeInt(boundPort)
PythonRDD.writeUTF(authHelper.secret, dataOut)
}
} else {
dataOut.writeBoolean(/* isBarrier = */false)
}
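
        // Handshake layout written so far for the barrier section: one boolean
        // (isBarrier), then for barrier tasks either the bound TCP port and the
        // auth secret, or -1 and the Unix domain socket path.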

        // Write out the TaskContextInfo
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
        dataOut.writeLong(context.taskAttemptId())
        dataOut.writeInt(context.cpus())
        val resources = context.resources()
        dataOut.writeInt(resources.size)
        resources.foreach { case (k, v) =>
          PythonRDD.writeUTF(k, dataOut)
          PythonRDD.writeUTF(v.name, dataOut)
          dataOut.writeInt(v.addresses.length)
          v.addresses.foreach { addr =>
            PythonRDD.writeUTF(addr, dataOut)
          }
        }
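
        // The resource map above is encoded as an int entry count, then per
        // resource: the map key, the resource name, an int address count, and
        // each address as a UTF string.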
        val localProps = context.getLocalProperties.asScala
        dataOut.writeInt(localProps.size)
        localProps.foreach { case (k, v) =>
          PythonRDD.writeUTF(k, dataOut)
          PythonRDD.writeUTF(v, dataOut)
        }
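
        // The local properties above use the same count-then-UTF-pairs encoding.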

        PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
        PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)

        dataOut.writeInt(evalType)
        writeCommand(dataOut)
        dataOut.flush()
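
        // At this point the worker has been sent, in order: the partition
        // index, the Python version, the barrier section, the task context
        // info, Spark files and broadcast variables, the eval type, and
        // finally the command payload written by the concrete writeCommand().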
      } catch {
        case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
          if (context.isCompleted() || context.isInterrupted()) {
            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
              "cleanup)", t)
            if (worker.channel.isConnected) {
              Utils.tryLog(worker.channel.shutdownOutput())
            }
          } else {
            // We must avoid rethrowing here, because the thread's uncaught exception
            // handler would kill the whole executor (see
            // org.apache.spark.executor.Executor).
            _exception = t
            if (worker.channel.isConnected) {
              Utils.tryLog(worker.channel.shutdownOutput())
            }
          }
      }
    }
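
    /**
     * Writes the END_OF_STREAM marker so the Python worker knows no more input
     * follows, then flushes the stream.
     */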
    def close(dataOut: DataOutputStream): Unit = {
      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
      dataOut.flush()
    }
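
    // Typical lifecycle as driven by the surrounding runner (an illustrative
    // sketch, not literal code from this file):
    //
    //   writer.open(dataOut)                                // handshake + command
    //   while (writer.writeNextInputToStream(dataOut)) ()   // stream the input
    //   writer.close(dataOut)                               // END_OF_STREAM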

    /**
     * Gateway to call BarrierTaskContext methods.
     */
    def barrierAndServe(requestMethod: Int, sock: SocketChannel, message: String = ""): Unit = {
      require(
        serverSocketChannel.isDefined,
        "No available ServerSocket to redirect the BarrierTaskContext method call."
      )
      val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
      try {
        val messages = requestMethod match {
          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
            context.asInstanceOf[BarrierTaskContext].barrier()
            Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
            context.asInstanceOf[BarrierTaskContext].allGather(message)
        }
        out.writeInt(messages.length)
        messages.foreach(writeUTF(_, out))
      } catch {
        case e: SparkException =>
          writeUTF(e.getMessage, out)
      } finally {
        out.close()
      }
    }
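
    /** Writes a length-prefixed UTF-8 string via PythonWorkerUtils.writeUTF. */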
    def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
      PythonWorkerUtils.writeUTF(str, dataOut)
    }
  }