def SparseDataDist()

in train/comms/pt/dlrm.py [0:0]


    def SparseDataDist(
        self, num_features_per_rank, input_features, global_rank, world_size, timers
    ):
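        # Only one feature group: nothing to redistribute.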
        if len(num_features_per_rank) == 1:
            return

        batch_size = input_features.batch_size
        device = input_features.lengths.device
        cpu_device = torch.device("cpu")

        # Start input_lengths_per_feature copy to host.
        input_lengths_per_feature = (
            input_features.lengths.view(input_features.count, input_features.batch_size)
            .sum(dim=1)
            .to(cpu_device, non_blocking=True)
        )
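        # Result: one entry per input feature, i.e. the total number of indices for
        # that feature across the local batch.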

        # Distribute lengths first, since the output lengths are needed to size the
        # indices and weights exchange. Then distribute indices and weights in parallel.
        num_my_features = num_features_per_rank[global_rank]
        output_lengths = torch.empty(
            num_my_features * batch_size * world_size,
            device=device,
            dtype=input_features.lengths.dtype,
        )
        # all2all_data: lengths
        out_splits = [num_my_features * batch_size] * world_size
        in_splits = [
            num_features * batch_size
            for num_features in num_features_per_rank
        ]
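        # Example (hypothetical values): with num_features_per_rank=[2, 3], batch_size=4
        # and global_rank=0: out_splits=[8, 8] (receive 8 lengths from every rank),
        # in_splits=[8, 12] (send num_features_per_rank[j] * batch_size lengths to rank j).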
        self.collectiveArgs.opTensor = output_lengths
        self.collectiveArgs.ipTensor = input_features.lengths
        self.collectiveArgs.opTensor_split = out_splits
        self.collectiveArgs.ipTensor_split = in_splits
        self.collectiveArgs.asyncOp = False
        self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)

        timers['offset_xchg_start'] = time.monotonic()
        self.backendFuncs.all_to_allv(self.collectiveArgs)
        self.backendFuncs.complete_accel_ops(self.collectiveArgs)
        timers['offset_xchg_end'] = time.monotonic()

        cur_iter_memory = output_lengths.element_size() * output_lengths.nelement()
        self.measured_regions['offset_xchg']['memory'].append(cur_iter_memory)

        prev_a2a_details = {
            "comms": "all_to_all",
            "msg_size": cur_iter_memory,
            "in_split": in_splits,
            "out_split": out_splits,
            "dtype": str(input_features.lengths.dtype),
        }
        self.commDetails.append(prev_a2a_details)

        # Start alltoall request for 'indices'.
        # Allocate the receive buffer for indices; output_lengths.sum().item() copies the
        # total count back to the host.
        output_indices = torch.empty(
            output_lengths.sum().item(),
            device=device,
            dtype=torch.int64,  # input_features.indices.dtype
        )
        output_indices_splits = (
            output_lengths.view(world_size, -1)
            .sum(dim=1)
            .to(cpu_device)
            .numpy()
        )
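        # output_indices_splits[i] is the number of indices to receive from rank i:
        # output_lengths holds one contiguous block of num_my_features * batch_size
        # lengths per source rank.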

        input_features_splits = []
        feature_offset = 0
        input_lengths_per_feature = input_lengths_per_feature.numpy()
        for feature_count in num_features_per_rank:
            feature_length = sum(
                input_lengths_per_feature[
                    feature_offset : feature_offset + feature_count
                ]
            )
            input_features_splits.append(feature_length)
            feature_offset += feature_count
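        # Example (hypothetical values, same layout as above): with per-feature totals
        # input_lengths_per_feature=[4, 1, 3, 2, 5], grouping by owner rank gives
        # input_features_splits=[4+1, 3+2+5]=[5, 10].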

        self.collectiveArgs.opTensor = output_indices
        self.collectiveArgs.ipTensor = input_features.indices
        self.collectiveArgs.opTensor_split = output_indices_splits
        self.collectiveArgs.ipTensor_split = input_features_splits
        self.collectiveArgs.asyncOp = False

        self.backendFuncs.complete_accel_ops(self.collectiveArgs, initOp=True)
        timers['idx_xchg_start'] = time.monotonic()
        self.backendFuncs.all_to_allv(self.collectiveArgs)
        self.backendFuncs.complete_accel_ops(self.collectiveArgs)
        timers['idx_xchg_end'] = time.monotonic()

        cur_iter_memory = output_indices.element_size() * output_indices.nelement()
        self.measured_regions['idx_xchg']['memory'].append(cur_iter_memory)
        cur_a2a_details = {
            "comms": "all_to_all",
            "msg_size": cur_iter_memory,
            "in_split": np.array(input_features_splits, dtype=np.int32).tolist(),
            "out_split": np.array(output_indices_splits, dtype=np.int32).tolist(),
            "dtype": str(input_features.indices.dtype),
        }
        self.commDetails.append(cur_a2a_details)

        # By now we have received the lengths and indices of the local tables for the
        # global batch. Split them per table -- the logic is explained in splitPerTable.
        offsets, indices = self.paramNN.splitPerTable(
            output_lengths, output_indices, batch_size, num_my_features,
            world_size, global_rank, device
        )
        return (offsets, indices)
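
The split arithmetic driving the two all_to_allv calls can be checked in isolation. The sketch below is a standalone illustration, not part of dlrm.py, using made-up sizes: a hypothetical 2-rank layout where rank 0 owns 2 features and rank 1 owns 3, with a local batch of 4.

    import numpy as np

    # Hypothetical layout: rank 0 owns 2 features, rank 1 owns 3; each rank reads a
    # local batch of 4 samples for all 5 features.
    num_features_per_rank = [2, 3]
    batch_size = 4
    world_size = len(num_features_per_rank)
    global_rank = 0
    num_my_features = num_features_per_rank[global_rank]

    # Lengths exchange: receive num_my_features * batch_size lengths from every rank,
    # send num_features_per_rank[j] * batch_size lengths to rank j.
    out_splits = [num_my_features * batch_size] * world_size
    in_splits = [n * batch_size for n in num_features_per_rank]
    assert out_splits == [8, 8] and in_splits == [8, 12]

    # Indices exchange: per-feature index totals over the local batch, grouped by the
    # rank that owns each feature, give the send splits.
    input_lengths_per_feature = np.array([4, 1, 3, 2, 5])
    input_features_splits, feature_offset = [], 0
    for feature_count in num_features_per_rank:
        input_features_splits.append(
            int(input_lengths_per_feature[feature_offset : feature_offset + feature_count].sum())
        )
        feature_offset += feature_count
    assert input_features_splits == [5, 10]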