in train/comms/pt/pytorch_dist_backend.py [0:0]
def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
    # pair=True mode does not support quantization
    if (
        collectiveArgs.all2all_qcomm
        and collectiveArgs.ipTensor.dtype == torch.float32
        and (
            collectiveArgs.opTensor.nelement() >= collectiveArgs.quant_threshold
            or collectiveArgs.ipTensor.nelement() >= collectiveArgs.quant_threshold
        )
        and not pair
    ):
        # Quantized path: used only for fp32 tensors large enough to cross
        # the quantization threshold, and never for the pair tensors.
        work = all_to_allv_internal(collectiveArgs)
    else:
        # Plain all-to-all with explicit per-rank split sizes; pair=True
        # swaps in the *_pair tensors and split lists.
        work = dist.all_to_all_single(
            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
            collectiveArgs.opTensor_split
            if not pair
            else collectiveArgs.opTensor_split_pair,
            collectiveArgs.ipTensor_split
            if not pair
            else collectiveArgs.ipTensor_split_pair,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )

    if collectiveArgs.asyncOp:
        # Track the async work handle so the caller can wait on it later.
        collectiveArgs.waitObj.append(work)

    if retFlag:
        return work
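
For context, a minimal standalone sketch of what the non-quantized branch does: a torch.distributed.all_to_all_single call with explicit split sizes. This is illustrative only and not part of the benchmark harness; the function name, split sizes, and the assumption that dist.init_process_group has already been called are all hypothetical.

import torch
import torch.distributed as dist

def demo_all_to_allv():
    # Assumes a process group (e.g. nccl or gloo) is already initialized.
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Illustrative split sizes: every rank sends r + 1 elements to rank r,
    # so rank r receives (rank + 1) elements from each peer.
    in_splits = [r + 1 for r in range(world_size)]
    out_splits = [rank + 1] * world_size

    ip_tensor = torch.arange(sum(in_splits), dtype=torch.float32)
    op_tensor = torch.empty(sum(out_splits), dtype=torch.float32)

    # async_op=True returns a work handle, mirroring collectiveArgs.asyncOp.
    work = dist.all_to_all_single(
        op_tensor, ip_tensor, out_splits, in_splits, async_op=True
    )
    work.wait()
    return op_tensor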