in train/comms/pt/pytorch_dist_backend.py [0:0]
def all_reduce(self, collectiveArgs, retFlag=False, pair=False):
    """Perform an all-reduce over the tensor(s) carried by collectiveArgs.

    When quantized communication is configured (allreduce_qcomm strictly
    between 4 and 32 bits, float32 input, and not pair mode), the input is
    downcast before the collective and dequantized afterwards; otherwise
    the input tensor (or the pair tensor when pair=True) is reduced as-is.
    Returns the work/result object only when retFlag is True.
    """
    # pair=True mode does not support quantization, so the gate includes
    # `not pair`. Operand order is kept to preserve short-circuit behavior.
    useQuantization = (
        collectiveArgs.allreduce_qcomm != 32
        and collectiveArgs.allreduce_qcomm > 4
        and collectiveArgs.ipTensor.dtype == torch.float32
        and not pair
    )
    if useQuantization:
        # NOTE: _downcast yields a fresh tensor distinct from
        # collectiveArgs.ipTensor, so once the collective completes the
        # result lives in that new tensor, not in ipTensor. This is
        # deliberate — it avoids reallocating buffers on every call — but
        # it means the quantized all_reduce output is only suitable for
        # benchmarking, not for downstream use.
        with paramProfile(
            timer=collectiveArgs.quant_time,
            description="# PARAM: Allreduce quantization #",
        ):
            commTensor = _downcast(
                collectiveArgs.ipTensor, collectiveArgs.allreduce_qcomm
            )
    else:
        commTensor = (
            collectiveArgs.ipTensor_pair if pair else collectiveArgs.ipTensor
        )

    # Synchronicity is maintained in runColl.
    retObj = dist.all_reduce(
        commTensor,
        op=collectiveArgs.op,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )

    # Dequantize only when a downcast copy was actually reduced
    # (identity check: the quantized path produced a new tensor).
    if commTensor is not collectiveArgs.ipTensor and not pair:
        if collectiveArgs.asyncOp:
            # Chain dequantization onto the async future.
            retObj = retObj.get_future().then(_dequantize)
        else:
            with paramProfile(
                timer=collectiveArgs.dequant_time,
                description="# PARAM: Allreduce de-quantization #",
            ):
                retObj = _dequantize(commTensor)

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)

    if retFlag:
        return retObj