in client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java [1398:1726]
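  /**
   * Pushes a group of serialized batches, all destined for partitions on the same worker, as a
   * single PushMergedData RPC. The wrapped callback interprets the worker's status byte and
   * handles splits, congestion, and revive-based retries before reporting the final outcome.
   */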
  private void doPushMergedData(
      Pair<String, String> addressPair,
      int shuffleId,
      int mapId,
      int attemptId,
      ArrayList<DataBatches.DataBatch> batches,
      PushState pushState,
      int remainReviveTimes) {
    String hostPort = addressPair.getLeft();
    String[] hostPortArr = Utils.parseColonSeparatedHostPorts(hostPort, 1);
    final String host = hostPortArr[0];
    final int port = Integer.parseInt(hostPortArr[1]);
    int groupedBatchId = pushState.nextBatchId();
    pushState.addBatch(groupedBatchId, hostPort);
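    // Per-batch metadata arrays are index-aligned with `batches`; offsets[i]
    // records where batch i starts in the merged payload (e.g. bodies of 10,
    // 20 and 30 bytes yield offsets {0, 10, 30}).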
    final int numBatches = batches.size();
    final Integer[] partitionIds = new Integer[numBatches];
    final String[] partitionUniqueIds = new String[numBatches];
    final int[] offsets = new int[numBatches];
    final int[] batchIds = new int[numBatches];
    int currentSize = 0;
    CompositeByteBuf byteBuf = Unpooled.compositeBuffer();
    for (int i = 0; i < numBatches; i++) {
      DataBatches.DataBatch batch = batches.get(i);
      partitionIds[i] = batch.loc.getId();
      partitionUniqueIds[i] = batch.loc.getUniqueId();
      offsets[i] = currentSize;
      batchIds[i] = batch.batchId;
      currentSize += batch.body.length;
      byteBuf.addComponent(true, Unpooled.wrappedBuffer(batch.body));
    }
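    // The composite buffer stitches all batch bodies together without copying,
    // so the whole group travels as one PushMergedData RPC payload.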
    NettyManagedBuffer buffer = new NettyManagedBuffer(byteBuf);
    String shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId);
    PushMergedData mergedData =
        new PushMergedData(PRIMARY_MODE, shuffleKey, partitionUniqueIds, offsets, buffer);
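    // Terminal callback: logs the outcome, clears the in-flight batch on
    // success, records MAP_ENDED when the worker reports it, and captures the
    // first failure on the push state.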
    RpcResponseCallback callback =
        new RpcResponseCallback() {
          @Override
          public void onSuccess(ByteBuffer response) {
            logger.debug(
                "Push merged data to {} success for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
                addressPair,
                shuffleId,
                mapId,
                attemptId,
                Arrays.toString(partitionIds),
                groupedBatchId,
                Arrays.toString(batchIds));
            pushState.removeBatch(groupedBatchId, hostPort);
            if (response.remaining() > 0 && response.get() == StatusCode.MAP_ENDED.getValue()) {
              mapperEndMap
                  .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
                  .add(mapId);
            }
          }

          @Override
          public void onFailure(Throwable e) {
            String errorMsg =
                String.format(
                    "Push merged data to %s failed for shuffle %d map %d attempt %d partition %s groupedBatch %d batch %s, remain revive times %d.",
                    addressPair,
                    shuffleId,
                    mapId,
                    attemptId,
                    Arrays.toString(partitionIds),
                    groupedBatchId,
                    Arrays.toString(batchIds),
                    remainReviveTimes);
            pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
            if (logger.isDebugEnabled()) {
              for (int i = 0; i < numBatches; i++) {
                logger.debug(
                    "Push merged data to {} failed for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}, remain revive times {}.",
                    addressPair,
                    shuffleId,
                    mapId,
                    attemptId,
                    partitionIds[i],
                    groupedBatchId,
                    batchIds[i],
                    remainReviveTimes);
              }
            }
          }
        };
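    // Wrapping callback: interprets the leading status byte of the worker's
    // response and routes to split handling, congestion control, or retry
    // before delegating the final outcome to `callback`.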
    RpcResponseCallback wrappedCallback =
        new RpcResponseCallback() {
          @Override
          public void onSuccess(ByteBuffer response) {
            byte reason = response.get();
            if (reason == StatusCode.HARD_SPLIT.getValue()) {
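              // With [CELEBORN-1721], a HARD_SPLIT reply may carry a
              // PbPushMergedDataSplitPartitionInfo naming exactly which
              // sub-batches hit a split and with what status; only those need
              // handling, the rest are treated as stored.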
              ArrayList<DataBatches.DataBatch> batchesNeedResubmit;
              if (response.remaining() > 0) {
                batchesNeedResubmit = new ArrayList<>();
                PbPushMergedDataSplitPartitionInfo partitionInfo;
                try {
                  partitionInfo = TransportMessage.fromByteBuffer(response).getParsedPayload();
                } catch (CelebornIOException | InvalidProtocolBufferException e) {
                  callback.onFailure(
                      new CelebornIOException("parse pushMergedData response failed", e));
                  return;
                }
                List<Integer> splitPartitionIndexes = partitionInfo.getSplitPartitionIndexesList();
                List<Integer> statusCodeList = partitionInfo.getStatusCodesList();
                StringBuilder dataBatchReviveInfos = new StringBuilder();
                for (int i = 0; i < splitPartitionIndexes.size(); i++) {
                  int partitionIndex = splitPartitionIndexes.get(i);
                  int batchId = batches.get(partitionIndex).batchId;
                  dataBatchReviveInfos.append(
                      String.format(
                          "(batchId=%d, partitionId=%d, cause=%s)",
                          batchId,
                          partitionIds[partitionIndex],
                          StatusCode.fromValue(statusCodeList.get(i).byteValue())));
                  if (statusCodeList.get(i) == StatusCode.SOFT_SPLIT.getValue()) {
                    // SOFT_SPLIT: the data was accepted; just request a newer
                    // location epoch asynchronously if none is known yet.
                    PartitionLocation loc = batches.get(partitionIndex).loc;
                    if (!newerPartitionLocationExists(
                        reducePartitionMap.get(shuffleId), loc.getId(), loc.getEpoch(), false)) {
                      ReviveRequest reviveRequest =
                          new ReviveRequest(
                              shuffleId,
                              mapId,
                              attemptId,
                              loc.getId(),
                              loc.getEpoch(),
                              loc,
                              StatusCode.SOFT_SPLIT);
                      reviveManager.addRequest(reviveRequest);
                    }
                  } else {
                    // Anything else means the batch was rejected and must be
                    // pushed again after a revive.
                    batchesNeedResubmit.add(batches.get(partitionIndex));
                  }
                }
                logger.info(
                    "Push merged data to {} partially succeeded for shuffle {} map {} attempt {} groupedBatch {}, split batches {}.",
                    addressPair,
                    shuffleId,
                    mapId,
                    attemptId,
                    groupedBatchId,
                    dataBatchReviveInfos);
              } else {
                // Workers that do not incorporate changes from [CELEBORN-1721]
                // will respond with a status of HARD_SPLIT, but will not
                // include a PbPushMergedDataSplitPartitionInfo. For backward
                // compatibility, all batches must be resubmitted.
                batchesNeedResubmit = batches;
                logger.info(
                    "Push merged data to {} hard split required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
                    addressPair,
                    shuffleId,
                    mapId,
                    attemptId,
                    Arrays.toString(partitionIds),
                    groupedBatchId,
                    Arrays.toString(batchIds));
              }
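              // If every split was soft, nothing needs resubmission and the
              // group counts as a success; otherwise schedule revives plus a
              // retry for the rejected batches.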
              if (batchesNeedResubmit.isEmpty()) {
                pushState.onSuccess(hostPort);
                callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()}));
              } else {
                if (dataPushFailureTrackingEnabled && pushReplicateEnabled) {
                  // With failure tracking and replication enabled, remember
                  // which batches were rejected.
                  for (DataBatches.DataBatch resubmitBatch : batchesNeedResubmit) {
                    pushState.addFailedBatch(
                        resubmitBatch.loc.getUniqueId(),
                        new PushFailedBatch(mapId, attemptId, resubmitBatch.batchId));
                  }
                }
                ReviveRequest[] requests =
                    addAndGetReviveRequests(
                        shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT);
                pushDataRetryPool.submit(
                    () ->
                        submitRetryPushMergedData(
                            pushState,
                            shuffleId,
                            mapId,
                            attemptId,
                            batchesNeedResubmit,
                            StatusCode.HARD_SPLIT,
                            groupedBatchId,
                            requests,
                            remainReviveTimes,
                            System.currentTimeMillis()
                                + conf.clientRpcRequestPartitionLocationAskTimeout()
                                    .duration()
                                    .toMillis()));
              }
            } else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
              // The data was stored, but the primary worker is congested:
              // apply client-side congestion control before reporting success.
              logger.debug(
                  "Push merged data to {} primary congestion required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
                  addressPair,
                  shuffleId,
                  mapId,
                  attemptId,
                  Arrays.toString(partitionIds),
                  groupedBatchId,
                  Arrays.toString(batchIds));
              pushState.onCongestControl(hostPort);
              callback.onSuccess(response);
            } else if (reason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
              // Same as above, but the congestion was reported by the replica.
              logger.debug(
                  "Push merged data to {} replica congestion required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.",
                  addressPair,
                  shuffleId,
                  mapId,
                  attemptId,
                  Arrays.toString(partitionIds),
                  groupedBatchId,
                  Arrays.toString(batchIds));
              pushState.onCongestControl(hostPort);
              callback.onSuccess(response);
            } else if (reason == StatusCode.MAP_ENDED.getValue()) {
              pushState.onSuccess(hostPort);
              callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.MAP_ENDED.getValue()}));
            } else { // success
              pushState.onSuccess(hostPort);
              callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SUCCESS.getValue()}));
            }
          }
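          // Failure of the grouped push: track failed batches if enabled,
          // propagate interrupts, and otherwise revive and retry until the
          // revive budget is exhausted.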
          @Override
          public void onFailure(Throwable e) {
            if (dataPushFailureTrackingEnabled) {
              // Mark every batch in the group as failed.
              for (int i = 0; i < numBatches; i++) {
                pushState.addFailedBatch(
                    partitionUniqueIds[i], new PushFailedBatch(mapId, attemptId, batchIds[i]));
              }
            }
            if (pushState.exception.get() != null) {
              // A failure was already recorded on this push state; nothing
              // more to do.
              return;
            }
            if (e instanceof InterruptedException) {
              Thread.currentThread().interrupt();
              callback.onFailure(e);
              return;
            }
            StatusCode cause = getPushDataFailCause(e.getMessage());
            if (remainReviveTimes <= 0) {
              // Revive budget exhausted: surface the failure as-is.
              if (e instanceof CelebornIOException) {
                callback.onFailure(e);
              } else {
                callback.onFailure(new CelebornIOException(cause, e));
              }
              return;
            }
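            // Revive budget remains: log the failure and, unless the mapper
            // already ended, schedule a revive plus retry with one fewer
            // attempt remaining.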
            logger.error(
                "Push merged data to {} failed for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}, remain revive times {}.",
                addressPair,
                shuffleId,
                mapId,
                attemptId,
                Arrays.toString(partitionIds),
                groupedBatchId,
                Arrays.toString(batchIds),
                remainReviveTimes,
                e);
            if (!mapperEnded(shuffleId, mapId)) {
              ReviveRequest[] requests =
                  addAndGetReviveRequests(shuffleId, mapId, attemptId, batches, cause);
              pushDataRetryPool.submit(
                  () ->
                      submitRetryPushMergedData(
                          pushState,
                          shuffleId,
                          mapId,
                          attemptId,
                          batches,
                          cause,
                          groupedBatchId,
                          requests,
                          remainReviveTimes - 1,
                          System.currentTimeMillis()
                              + conf.clientRpcRequestPartitionLocationAskTimeout()
                                  .duration()
                                  .toMillis()));
            } else {
              // The map task has already ended, so this data is no longer
              // needed; just drop the in-flight batch.
              pushState.removeBatch(groupedBatchId, hostPort);
              logger.info(
                  "Push merged data to {} failed but mapper already ended for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}, remain revive times {}.",
                  hostPort,
                  shuffleId,
                  mapId,
                  attemptId,
                  Arrays.toString(partitionIds),
                  groupedBatchId,
                  Arrays.toString(batchIds),
                  remainReviveTimes);
            }
          }
        };
    // Do the actual push, unless the target worker is currently excluded.
    // testRetryRevive is a test hook that forces the failure path (note the
    // mocked exception) to exercise revive handling.
    try {
      if (!isPushTargetWorkerExcluded(batches.get(0).loc, wrappedCallback)) {
        if (!testRetryRevive || remainReviveTimes < 1) {
          assert dataClientFactory != null;
          TransportClient client = dataClientFactory.createClient(host, port);
          client.pushMergedData(mergedData, pushDataTimeout, wrappedCallback);
        } else {
          wrappedCallback.onFailure(
              new CelebornIOException(
                  StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_PRIMARY,
                  new RuntimeException("Mock push merge data failed.")));
        }
      }
    } catch (Exception e) {
      logger.error(
          "Exception raised while pushing merged data for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {} location {}.",
          shuffleId,
          mapId,
          attemptId,
          Arrays.toString(partitionIds),
          groupedBatchId,
          Arrays.toString(batchIds),
          addressPair,
          e);
      if (e instanceof InterruptedException) {
        wrappedCallback.onFailure(e);
      } else {
        wrappedCallback.onFailure(
            new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, e));
      }
    }
  }