in src/main/scala/org/apache/spark/shuffle/rss/WriterBufferManager.scala [55:115]
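/**
 * Serializes one record into the in-memory buffer for its partition and returns any
 * buffered data that must be spilled now: the single partition whose buffer reached
 * bufferSize, or every buffered partition once the total buffered bytes reach spillSize.
 * Returns Nil when all buffered data still fits within the thresholds.
 */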
private[rss] def addRecordImpl(partitionId: Int, record: Product2[Any, Any]): Seq[(Int, Array[Byte])] = {
  var result: mutable.Buffer[(Int, Array[Byte])] = null
  recordsWrittenCount += 1
  map.get(partitionId) match {
    case Some(v) =>
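      // The partition already has an open buffer: append the record to its serialize
      // stream and track how many bytes the buffer grew by.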
      val stream = v.serializeStream
      val oldSize = v.output.position()
      stream.writeKey(record._1)
      stream.writeValue(record._2)
      val newSize = v.output.position()
      if (newSize >= bufferSize) {
        // partition buffer is full, add it to the result as spill data
        if (result == null) {
          result = mutable.Buffer[(Int, Array[Byte])]()
        }
        v.serializeStream.flush()
        result.append((partitionId, v.output.toBytes))
        v.serializeStream.close()
        map.remove(partitionId)
        totalBytes -= oldSize
      } else {
        totalBytes += (newSize - oldSize)
      }
    case None =>
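      // No buffer exists for this partition yet: start a new Kryo output and serialize
      // stream, and keep them in the map unless this first record already fills the buffer.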
      val output = new Output(bufferSize, maxBufferSize)
      val stream = serializerInstance.serializeStream(output)
      stream.writeKey(record._1)
      stream.writeValue(record._2)
      val newSize = output.position()
      if (newSize >= bufferSize) {
        // partition buffer is full, add it to the result as spill data
        if (result == null) {
          result = mutable.Buffer[(Int, Array[Byte])]()
        }
        stream.flush()
        result.append((partitionId, output.toBytes))
        stream.close()
      } else {
        map.put(partitionId, WriterBufferManagerValue(stream, output))
        totalBytes = totalBytes + newSize
      }
  }
  if (totalBytes >= spillSize) {
    // data for all partitions exceeds threshold, add all data to the result as spill data
    if (result == null) {
      result = mutable.Buffer[(Int, Array[Byte])]()
    }
    map.values.foreach(_.serializeStream.flush())
    result.appendAll(map.map(t => (t._1, t._2.output.toBytes)))
    map.foreach(t => t._2.serializeStream.close())
    map.clear()
    totalBytes = 0
  }
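  // Hand back whatever was flushed above; Nil means nothing needs to be sent out yet.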
  if (result == null) {
    Nil
  } else {
    result
  }
}
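For illustration only, a minimal self-contained sketch of the same buffer-and-spill pattern: it replaces the Kryo serialize streams with plain byte arrays, and the SimpleSpillBuffer class and its add method are hypothetical names, not part of this project.

import scala.collection.mutable

// Hypothetical, simplified version of the per-partition buffer-and-spill logic above.
class SimpleSpillBuffer(bufferSize: Int, spillSize: Int) {
  private val map = mutable.Map[Int, mutable.ArrayBuffer[Byte]]()
  private var totalBytes = 0L

  def add(partitionId: Int, bytes: Array[Byte]): Seq[(Int, Array[Byte])] = {
    val buf = map.getOrElseUpdate(partitionId, mutable.ArrayBuffer[Byte]())
    buf ++= bytes
    totalBytes += bytes.length
    val spilled = mutable.Buffer[(Int, Array[Byte])]()
    if (buf.size >= bufferSize) {
      // One partition's buffer reached its limit: spill just that partition.
      spilled.append((partitionId, buf.toArray))
      totalBytes -= buf.size
      map.remove(partitionId)
    }
    if (totalBytes >= spillSize) {
      // Total buffered bytes reached the spill threshold: spill every remaining partition.
      spilled.appendAll(map.map { case (p, b) => (p, b.toArray) })
      map.clear()
      totalBytes = 0
    }
    spilled.toSeq
  }
}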