in train/comms/pt/dlrm.py [0:0]
def SparseDataDist(self, num_features_per_rank, input_features, global_rank, world_size, timers):
    """Redistribute sparse (embedding) features across ranks with two all-to-allv exchanges.

    Phase 1 exchanges per-sample lengths so each rank learns how many indices
    it will receive; phase 2 exchanges the indices themselves, sized and split
    using the lengths received in phase 1. Wall-clock timestamps for each
    exchange are written into ``timers``, and message sizes / split metadata
    are appended to ``self.measured_regions`` and ``self.commDetails``.

    Args:
        num_features_per_rank: per-rank count of sparse features
            (``len(num_features_per_rank) == world_size``).
        input_features: batched sparse-feature object exposing ``batch_size``,
            ``count``, ``lengths`` and ``indices`` tensors
            (KeyedJaggedTensor-like — TODO confirm exact type against caller).
        global_rank: this process's rank in the global group.
        world_size: total number of participating ranks.
        timers: dict mutated in place with ``time.monotonic()`` start/end
            stamps for the 'offset_xchg' and 'idx_xchg' regions.

    Returns:
        ``(offsets, indices)`` split per local table by
        ``self.paramNN.splitPerTable``, or ``None`` when there is only one
        feature group and nothing needs to be distributed.
    """
    # Single feature group: every rank already holds its data — nothing to do.
    if len(num_features_per_rank) == 1:
        return
    batch_size = input_features.batch_size
    device = input_features.lengths.device
    cpu_device = torch.device("cpu")
    # Kick off the per-feature length reduction and async device->host copy
    # now; the result is only consumed later when building the indices splits,
    # so the copy overlaps with the lengths exchange below.
    input_lengths_per_feature = (
        input_features.lengths.view(input_features.count, input_features.batch_size)
        .sum(dim=1)
        .to(cpu_device, non_blocking=True)
    )
    # Distribute lengths first, as we need to know output lengths
    # before we can size the indices exchange.
    num_my_features = num_features_per_rank[global_rank]
    output_lengths = torch.empty(
        num_my_features * batch_size * world_size,
        device=device,
        dtype=input_features.lengths.dtype,
    )
    # This rank receives num_my_features * batch_size lengths from every peer,
    # and sends each peer that peer's feature share of the local batch.
    out_splits = [num_my_features * batch_size] * world_size
    in_splits = [
        num_features * batch_size
        for num_features in num_features_per_rank
    ]
    self.collectiveArgs.opTensor = output_lengths
    self.collectiveArgs.ipTensor = input_features.lengths
    self.collectiveArgs.opTensor_split = out_splits
    self.collectiveArgs.ipTensor_split = in_splits
    self.collectiveArgs.asyncOp = False
    # Drain outstanding accelerator work so the timed region covers only the
    # exchange itself.
    self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
    timers['offset_xchg_start'] = time.monotonic()
    self.backendFuncs.all_to_allv(self.collectiveArgs)
    self.backendFuncs.complete_accel_ops(self.collectiveArgs)
    timers['offset_xchg_end'] = time.monotonic()
    # Record the received-message footprint and replayable call details.
    cur_iter_memory = output_lengths.element_size() * output_lengths.nelement()
    self.measured_regions['offset_xchg']['memory'].append(cur_iter_memory)
    prev_a2a_details = {
        "comms" : "all_to_all",
        "msg_size" : cur_iter_memory,
        "in_split" : in_splits,
        "out_split" : out_splits,
        "dtype" : str(input_features.lengths.dtype),
    }
    self.commDetails.append(prev_a2a_details)
    # Phase 2: exchange the indices. Output size is the sum of all received
    # lengths (.item() syncs on the lengths exchange having completed).
    # NOTE(review): dtype is hard-coded to int64 rather than
    # input_features.indices.dtype — confirm all senders use int64 indices.
    output_indices = torch.empty(
        output_lengths.sum().item(),
        device=device,
        dtype=torch.int64,
    )
    # Receive-side splits: total indices coming from each rank.
    output_indices_splits = (
        output_lengths.view(world_size, -1)
        .sum(dim=1)
        .to(cpu_device)
        .numpy()
    )
    # Send-side splits: group the per-feature totals by destination rank,
    # walking num_features_per_rank in rank order.
    input_features_splits = []
    feature_offset = 0
    input_lengths_per_feature = input_lengths_per_feature.numpy()
    for feature_count in num_features_per_rank:
        feature_length = sum(
            input_lengths_per_feature[
                feature_offset : feature_offset + feature_count
            ]
        )
        input_features_splits.append(feature_length)
        feature_offset += feature_count
    self.collectiveArgs.opTensor = output_indices
    self.collectiveArgs.ipTensor = input_features.indices
    self.collectiveArgs.opTensor_split = output_indices_splits
    self.collectiveArgs.ipTensor_split = input_features_splits
    self.collectiveArgs.asyncOp = False
    self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
    timers['idx_xchg_start'] = time.monotonic()
    self.backendFuncs.all_to_allv(self.collectiveArgs)
    self.backendFuncs.complete_accel_ops(self.collectiveArgs)
    timers['idx_xchg_end'] = time.monotonic()
    cur_iter_memory = output_indices.element_size() * output_indices.nelement()
    self.measured_regions['idx_xchg']['memory'].append(cur_iter_memory)
    cur_a2a_details = {
        "comms" : "all_to_all",
        "msg_size" : cur_iter_memory,
        # Splits are numpy scalars/arrays; normalize to plain ints for logging.
        "in_split" : np.array(input_features_splits, dtype=np.int32).tolist(),
        "out_split" : np.array(output_indices_splits, dtype=np.int32).tolist(),
        "dtype" : str(input_features.indices.dtype)
    }
    self.commDetails.append(cur_a2a_details)
    # By now we have received the lengths, and indices of local-table for the global-batch
    # We need to split the lengths and indices per table -- logic explained in splitPerTable
    offsets, indices = self.paramNN.splitPerTable(output_lengths, output_indices, batch_size, num_my_features, world_size, global_rank, device)
    return (offsets, indices)