in client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java [764:1114]
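/**
 * Pushes (doPush == true) or merges (doPush == false) one batch of serialized records for
 * the given shuffle partition. Returns the number of bytes framed for transfer
 * (header + payload), or 0 when the mapper has already ended.
 */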
public int pushOrMergeData(
int shuffleId,
int mapId,
int attemptId,
int partitionId,
byte[] data,
int offset,
int length,
int numMappers,
int numPartitions,
boolean doPush)
throws IOException {
// Key identifying this map task attempt, derived from shuffleId, mapId, and attemptId.
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
// Return early if this mapper has already ended; clean up any pending push state.
if (mapperEnded(shuffleId, mapId)) {
logger.debug(
"Push or merge data ignored because mapper already ended for shuffle {} map {} attempt {} partition {}.",
shuffleId,
mapId,
attemptId,
partitionId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
pushState.cleanup();
}
return 0;
}
// register shuffle if not registered
final ConcurrentHashMap<Integer, PartitionLocation> map =
getPartitionLocation(shuffleId, numMappers, numPartitions);
if (map == null) {
throw new CelebornIOException("Register shuffle failed for shuffle " + shuffleId + ".");
}
// Get the partition location.
// If a rerun or speculative task runs after the LifecycleManager has called stageEnd,
// registering the shuffle returns an empty location map, so the client must revive to
// obtain a new location.
if (!map.containsKey(partitionId)) {
if (!revive(
shuffleId,
mapId,
attemptId,
partitionId,
-1,
null,
StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
throw new CelebornIOException(
String.format("Revive for shuffle %s partition %d failed.", shuffleId, partitionId));
}
}
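// Check again: the mapper may have ended while this task waited on the revive.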
if (mapperEnded(shuffleId, mapId)) {
logger.debug(
"Push or merge data ignored because mapper already ended for shuffle {} map {} attempt {} partition {}.",
shuffleId,
mapId,
attemptId,
partitionId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
pushState.cleanup();
}
return 0;
}
final PartitionLocation loc = map.get(partitionId);
if (loc == null) {
throw new CelebornIOException(
String.format(
"Partition location for shuffle %s partition %d is NULL!", shuffleId, partitionId));
}
PushState pushState = getPushState(mapKey);
// Allocate the next batch id for this map attempt.
final int nextBatchId = pushState.nextBatchId();
if (shuffleCompressionEnabled) {
// Compress the payload with a thread-local compressor; framing below uses the
// compressed buffer and size.
final Compressor compressor = compressorThreadLocal.get();
compressor.compress(data, offset, length);
data = compressor.getCompressedBuffer();
offset = 0;
length = compressor.getCompressedTotalSize();
}
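// Frame the batch: a fixed header of four ints (mapId, attemptId, batchId, length) at
// offsets 0/4/8/12, followed by the (possibly compressed) payload at BATCH_HEADER_SIZE.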
final byte[] body = new byte[BATCH_HEADER_SIZE + length];
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length);
System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length);
if (doPush) {
// Block until in-flight batches to this worker drop below the configured limit.
limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort());
// Track this batch as in-flight against the target worker.
pushState.addBatch(nextBatchId, loc.hostAndPushPort());
// build PushData request
NettyManagedBuffer buffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body));
final String shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId);
PushData pushData = new PushData(PRIMARY_MODE, shuffleKey, loc.getUniqueId(), buffer);
// Build the terminal callback: it records final success (including STAGE_ENDED) or the
// first failure for this push state.
RpcResponseCallback callback =
new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
if (response.remaining() > 0 && response.get() == StatusCode.STAGE_ENDED.getValue()) {
stageEndShuffleSet.add(shuffleId);
}
logger.debug(
"Push data to {} success for shuffle {} map {} attempt {} partition {} batch {}.",
loc.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
}
@Override
public void onFailure(Throwable e) {
String errorMsg =
String.format(
"Push data to %s failed for shuffle %d map %d attempt %d partition %d batch %d.",
loc, shuffleId, mapId, attemptId, partitionId, nextBatchId);
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
}
};
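// Wrap the callback with recovery logic: handle soft/hard split and congestion
// responses, and revive-and-retry on failure up to maxReviveTimes.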
RpcResponseCallback wrappedCallback =
new PushDataRpcResponseCallback() {
int remainReviveTimes = maxReviveTimes;
PartitionLocation latest = loc;
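// Re-home the in-flight batch when a revive produces a new partition location.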
@Override
public void updateLatestPartition(PartitionLocation latest) {
pushState.addBatch(nextBatchId, latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, this.latest.hostAndPushPort());
this.latest = latest;
}
@Override
public void onSuccess(ByteBuffer response) {
if (response.remaining() > 0) {
byte reason = response.get();
if (reason == StatusCode.SOFT_SPLIT.getValue()) {
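// Soft split: the worker accepted this batch but the partition should be split
// soon. Trigger an asynchronous split and complete the push as a success.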
logger.debug(
"Push data to {} soft split required for shuffle {} map {} attempt {} partition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
splitPartition(shuffleId, partitionId, latest);
pushState.onSuccess(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
} else if (reason == StatusCode.HARD_SPLIT.getValue()) {
logger.debug(
"Push data to {} hard split required for shuffle {} map {} attempt {} partition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
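// Hard split: the worker rejected this batch. Request a new location (revive)
// and re-push the same body asynchronously; dueTime bounds how long the retry
// waits for the revive result.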
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
mapId,
attemptId,
partitionId,
latest.getEpoch(),
latest,
StatusCode.HARD_SPLIT);
reviveManager.addRequest(reviveRequest);
long dueTime =
System.currentTimeMillis()
+ conf.clientRpcRequestPartitionLocationRpcAskTimeout()
.duration()
.toMillis();
pushDataRetryPool.submit(
() ->
submitRetryPushData(
shuffleId,
body,
nextBatchId,
this,
pushState,
reviveRequest,
remainReviveTimes,
dueTime));
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
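// Primary congested: the batch was stored, but the worker signals
// back-pressure; record it so the push strategy can slow down.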
logger.debug(
"Push data to {} primary congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
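// Replica congested: same handling as the primary-congested case.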
logger.debug(
"Push data to {} replica congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
} else {
// Any other status (e.g. STAGE_ENDED): rewind so the terminal callback can
// re-read the status byte, then complete normally.
response.rewind();
pushState.onSuccess(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
}
} else {
pushState.onSuccess(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
}
}
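// On failure: give up once revive attempts are exhausted or another batch already
// recorded an exception; otherwise revive and retry asynchronously.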
@Override
public void onFailure(Throwable e) {
StatusCode cause = getPushDataFailCause(e.getMessage());
if (pushState.exception.get() != null) {
return;
}
if (remainReviveTimes <= 0) {
if (e instanceof CelebornIOException) {
callback.onFailure(e);
} else {
callback.onFailure(new CelebornIOException(cause, e));
}
return;
}
logger.error(
"Push data to {} failed for shuffle {} map {} attempt {} partition {} batch {}, remain revive times {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId,
remainReviveTimes,
e);
// Revive for a new location and retry the push asynchronously.
if (!mapperEnded(shuffleId, mapId)) {
remainReviveTimes = remainReviveTimes - 1;
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId, mapId, attemptId, partitionId, latest.getEpoch(), latest, cause);
reviveManager.addRequest(reviveRequest);
long dueTime =
System.currentTimeMillis()
+ conf.clientRpcRequestPartitionLocationRpcAskTimeout()
.duration()
.toMillis();
pushDataRetryPool.submit(
() ->
submitRetryPushData(
shuffleId,
body,
nextBatchId,
this,
pushState,
reviveRequest,
remainReviveTimes,
dueTime));
} else {
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
logger.info(
"Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} partition {} batch {}, remain revive times {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId,
remainReviveTimes);
}
}
};
// Send the push; if the target worker is excluded, the check fails the callback instead.
try {
if (!isPushTargetWorkerExcluded(loc, wrappedCallback)) {
if (!testRetryRevive) {
TransportClient client =
dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId);
client.pushData(pushData, pushDataTimeout, wrappedCallback);
} else {
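// Test-only hook: simulate a first-push failure to exercise the revive path.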
wrappedCallback.onFailure(
new CelebornIOException(
StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE,
new RuntimeException("Mock push data first time failed.")));
}
}
} catch (Exception e) {
logger.error(
"Exception raised while pushing data for shuffle {} map {} attempt {} partition {} batch {} location {}.",
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId,
loc,
e);
wrappedCallback.onFailure(
new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
}
} else {
// Merge mode: buffer this batch; once enough data accumulates for this destination,
// push the merged batches.
logger.debug("Merge batch {}.", nextBatchId);
Pair<String, String> addressPair = genAddressPair(loc);
boolean shouldPush = pushState.addBatchData(addressPair, loc, nextBatchId, body);
if (shouldPush) {
limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort());
DataBatches dataBatches = pushState.takeDataBatches(addressPair);
doPushMergedData(
addressPair,
shuffleId,
mapId,
attemptId,
dataBatches.requireBatches(),
pushState,
maxReviveTimes);
}
}
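// Report the full framed size (header + payload) back to the caller.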
return body.length;
}