in train/comms/pt/comms_utils.py [0:0]
def prepComm(self, curComm, commsParams):
    """Allocate the input and output tensors for the collective."""
    commOp = paramToCommName(
        curComm["comms"] if ("comms" in curComm.keys()) else commsParams.collective,
        supported_comms=self.backendFuncs.collectiveFunc.keys(),
    )
    if commOp in ("wait", "barrier"):
        # wait and barrier carry no payload, so no tensors are needed
        return ([], [])
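    # in_msg_size / out_msg_size are element counts and are used directly as
    # tensor sizes below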
    numElementsIn = curComm["in_msg_size"]
    # numElementsOut is only meaningful for out-of-place collectives and pt2pt
    numElementsOut = curComm["out_msg_size"]
    world_size = self.collectiveArgs.world_size
    dtype = commsParams.dtype
    curDevice = commsParams.device
    # scaleFactor = 1 if commsParams.collective == "all_to_all" else numElements * numElements
    scaleFactor = numElementsOut * numElementsOut
    opTensor = []
    if commsParams.dcheck == 1:
        # use predictable values for data validation check
        ipTensor = self.backendFuncs.alloc_ones(
            [numElementsIn], curDevice, dtype, scaleFactor=self.initVal
        )
    else:
        ipTensor = self.backendFuncs.alloc_random(
            [numElementsIn], curDevice, dtype, scaleFactor
        )
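    # allocate the output tensor(s) in the layout this collective expects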
if commOp == "all_to_allv":
# all_to_all(v) requires two tensors
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
# all_to_allv requires tensors to specify split
self.collectiveArgs.opTensor_split = (
curComm["out_split"] if ("out_split" in curComm.keys()) else []
)
self.collectiveArgs.ipTensor_split = (
curComm["in_split"] if ("in_split" in curComm.keys()) else []
)
elif commOp == "all_gather":
# allgather requires a tensor list, e.g., List[torch.Tensor]
for _ in range(world_size):
opTensor.append(
self.backendFuncs.alloc_random(
[numElementsIn], curDevice, dtype, scaleFactor
)
)
elif commOp == "all_gather_base":
# this is a single all gather with flat output tensor
opTensor = self.backendFuncs.alloc_random(
numElementsIn * world_size,
curDevice,
dtype,
scaleFactor,
)
elif commOp == "incast":
# incast requires a tensor list with length of src_ranks, e.g., List[torch.Tensor]
for _ in self.collectiveArgs.src_ranks:
opTensor.append(
self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
)
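    # reduce_scatter consumes a list of per-rank input tensors and produces a single output tensor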
elif commOp == "reduce_scatter":
ipTensor = []
if commsParams.dcheck == 1:
for _ in range(world_size):
ipTensor.append(
self.backendFuncs.alloc_ones(
[numElementsOut], curDevice, commsParams.dtype, self.initVal
)
)
else:
for _ in range(world_size):
ipTensor.append(
self.backendFuncs.alloc_random(
[numElementsOut], curDevice, commsParams.dtype, scaleFactor
)
)
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
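    # reduce_scatter_base takes a single flat input tensor covering all ranks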
elif commOp == "reduce_scatter_base":
ipTensor = []
if commsParams.dcheck == 1:
ipTensor = self.backendFuncs.alloc_ones(
numElementsOut * world_size,
curDevice,
commsParams.dtype,
self.initVal,
)
else:
ipTensor = self.backendFuncs.alloc_random(
numElementsOut * world_size,
curDevice,
commsParams.dtype,
scaleFactor,
)
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
    elif commOp in ("all_to_all", "pt2pt"):
        # pt2pt or out-of-place collectives
        opTensor = self.backendFuncs.alloc_random(
            [numElementsOut],
            curDevice,
            dtype,
            scaleFactor,
        )
    else:
        # in-place case for other collectives such as allreduce, reduce, broadcast
        opTensor = ipTensor
    return (ipTensor, opTensor)
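
For context, a caller typically feeds the returned pair straight into collectiveArgs before dispatching the collective. The sketch below is illustrative only: it assumes the surrounding object exposes collectiveArgs, backendFuncs (with the collectiveFunc table used above and a completion call such as complete_accel_ops), and the prepComm method itself; the helper name runOneComm is hypothetical and this is not the repo's actual replay loop.

def runOneComm(self, curComm, commsParams):
    # illustrative helper (hypothetical name), not part of comms_utils.py
    commOp = paramToCommName(
        curComm.get("comms", commsParams.collective),
        supported_comms=self.backendFuncs.collectiveFunc.keys(),
    )
    # wire the freshly allocated tensors into the shared collectiveArgs
    (self.collectiveArgs.ipTensor, self.collectiveArgs.opTensor) = self.prepComm(
        curComm, commsParams
    )
    # look up the collective in the backend's function table and launch it,
    # then block until the backend reports completion (assumed API)
    self.backendFuncs.collectiveFunc[commOp](self.collectiveArgs)
    self.backendFuncs.complete_accel_ops(self.collectiveArgs)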