in train/comms/pt/comms.py [0:0]
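# Note: this excerpt assumes the file's module-level imports (numpy as np, torch, comms_utils).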
def initCollectiveArgs(self, commsParams):
# Factored out of benchTime because lint flagged that function as too complex.
(
local_rank,
global_rank,
world_size,
group,
curDevice,
curHwDevice,
) = comms_utils.get_rank_details(
self.backendFuncs
)  # Get rank details from the backendFuncs object, since we cannot rely on MPI to launch all the processes (e.g., on TPU).
self.backendFuncs.sayHello() # Informs us where each process is running.
groups = self.backendFuncs.get_groups()
num_pgs = len(groups)
self.comm_size = world_size
self.global_rank = global_rank
comms_utils.fixBeginSize(
commsParams, world_size
)  # Ensure that all-reduce and all-to-all have at least one element per rank.
allSizes = comms_utils.getSizes(
commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
)  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.
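# Illustration (assumed getSizes semantics): beginSize=32, endSize=256, stepFactor=2
# would yield allSizes = [32, 64, 128, 256].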
if global_rank == 0:
print(
f"[Rank {global_rank:>3}] allSizes: {allSizes} local_rank: {local_rank} element_size: {commsParams.element_size}"
)
self.collectiveArgs.group = group
self.collectiveArgs.groups = groups
self.collectiveArgs.num_pgs = num_pgs
self.collectiveArgs.device = curDevice
self.collectiveArgs.world_size = world_size
self.collectiveArgs.numIters = commsParams.numIters
self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
self.collectiveArgs.global_rank = global_rank
self.collectiveArgs.backendFuncs = self.backendFuncs
self.collectiveArgs.collective = commsParams.collective
op = self.backendFuncs.get_reduce_op("sum")
self.collectiveArgs.op = op
self.collectiveArgs.srcOrDst = commsParams.srcOrDst
self.collectiveArgs.src_ranks = commsParams.src_ranks
self.collectiveArgs.dst_ranks = commsParams.dst_ranks
self.collectiveArgs.pair = commsParams.pair
self.collectiveArgs.collective_pair = commsParams.collective_pair
self.collectiveArgs.pt2pt = commsParams.pt2pt
self.collectiveArgs.window = commsParams.window
self.collectiveArgs.asyncOp = commsParams.blockingFlag != 1
if commsParams.bitwidth < 32:
comms_utils.initQuantCommCtx(self.collectiveArgs, commsParams)
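# (Assumption: initQuantCommCtx installs the quantization context used for low-bitwidth collectives.)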
if self.collectiveArgs.collective == "pt2pt":
self.checkPt2PtRanks()
else:
self.checkCollectiveRanks()
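# Both checks validate that the configured src/dst ranks are consistent with the chosen collective.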
computeFunc = self.backendFuncs.noop
if (
commsParams.mode != "comms"
):  # Compute-mode related initialization, when not running comms-only.
if commsParams.kernel == "gemm":
computeFunc = self.backendFuncs.gemm
mm_dim = commsParams.mm_dim
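# Build three square mm_dim x mm_dim operands on the host and move them to the device as float32 tensors.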
in1 = np.random.rand(mm_dim, mm_dim)
MMin1 = torch.FloatTensor(in1).to(curDevice)
in2 = np.random.rand(mm_dim, mm_dim)
MMin2 = torch.FloatTensor(in2).to(curDevice)
in3 = np.random.rand(mm_dim, mm_dim)
MMin3 = torch.FloatTensor(in3).to(curDevice)
MMout = self.backendFuncs.alloc_empty(
[mm_dim, mm_dim], commsParams.dtype, curDevice
)
self.collectiveArgs.MMout = MMout
self.collectiveArgs.MMin1 = MMin1
self.collectiveArgs.MMin2 = MMin2
self.collectiveArgs.MMin3 = MMin3
self.collectiveArgs.numComputePerColl = commsParams.num_compute
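# A minimal sketch of what the backend gemm kernel presumably computes with these
# operands (an assumption -- the actual kernel is defined by backendFuncs.gemm):
#   MMout = torch.mm(MMin1, MMin2)
#   MMout = torch.mm(MMout, MMin3)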
elif commsParams.kernel == "emb_lookup":
computeFunc = self.backendFuncs.emb_lookup
emb_dim = commsParams.emb_dim
num_embeddings = commsParams.num_embs
avg_length = commsParams.avg_len
batch_size = commsParams.batch_size
print(
f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
)
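# Build the embedding-lookup operands: a weight table, per-sample indices, and the
# offsets that delimit each sample's slice of the indices.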
self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
[num_embeddings, emb_dim], torch.double, curDevice
)
self.collectiveArgs.TableOffsets = torch.LongTensor(
[0, num_embeddings]
).to(curDevice)
self.collectiveArgs.Indices = torch.LongTensor(
np.random.randint(0, num_embeddings - 1, avg_length * batch_size)
).to(curDevice)
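# Note: np.random.randint's upper bound is exclusive, so these indices fall in
# [0, num_embeddings - 2] and the table's last row is never sampled.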
lengths = np.ones((1, batch_size)) * avg_length
flat_lengths = lengths.flatten()
self.collectiveArgs.Offsets = torch.LongTensor(
[0] + np.cumsum(flat_lengths).tolist()
).to(curDevice)
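# Example (derived from the code above): batch_size=4, avg_length=2 gives
# flat_lengths = [2, 2, 2, 2] and Offsets = [0, 2, 4, 6, 8].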
self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
[batch_size, emb_dim], torch.double, curDevice
)
self.collectiveArgs.AvgLengths = avg_length
self.collectiveArgs.numComputePerColl = commsParams.num_compute
return (
local_rank,
global_rank,
world_size,
group,
curDevice,
curHwDevice,
allSizes,
computeFunc,
)
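# Illustrative call site (a sketch; the caller's variable names are assumptions):
#   (
#       local_rank, global_rank, world_size, group,
#       curDevice, curHwDevice, allSizes, computeFunc,
#   ) = self.initCollectiveArgs(commsParams)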