def _wait_impl()

in torchrec/distributed/dist_data.py [0:0]


    def _wait_impl(self) -> KeyedJaggedTensor:
        """Wait for the exchanged tensors and assemble the output KeyedJaggedTensor.

        With a single worker nothing was exchanged, so the input is synced and
        returned unchanged. Otherwise this blocks on the values awaitable (and
        on the weights awaitable when one exists). It then optionally reorders
        the received lengths/values/weights according to ``self._recat``, and
        builds a new KeyedJaggedTensor from the resulting lengths.

        Returns:
            KeyedJaggedTensor: ``self._input`` (after ``sync()``) when
            ``self._workers == 1``; otherwise a KJT built from ``self._keys``
            and the received lengths/values/weights.
        """
        if self._workers == 1:
            self._input.sync()
            return self._input

        self._values_awaitable.wait()
        # Weights are optional: only wait when a weights awaitable was created.
        if self._weights_awaitable is not None:
            self._weights_awaitable.wait()

        lengths = self._lengths
        values = self._values
        weights = self._weights

        with record_function("## all2all_data:recat_values ##"):
            # An empty recat tensor means no reordering is applied.
            if self._recat.numel():
                # Lengths are viewed as 2D with workers * splits[rank] rows
                # (presumably one row per received feature block — TODO
                # confirm), and the fbgemm op reorders them together with the
                # matching values/weights.
                lengths, values, weights = torch.ops.fbgemm.permute_sparse_data(
                    self._recat,
                    lengths.view(self._workers * self._splits[self._pg.rank()], -1),
                    values,
                    weights,
                    values.numel(),
                )
                # Flatten back to the 1D lengths layout expected below.
                lengths = lengths.view(-1)

        return KeyedJaggedTensor.from_lengths_sync(
            keys=self._keys,
            values=values,
            weights=weights,
            lengths=lengths,
            # NOTE(review): assumes every worker contributes a batch with the
            # same stride as the local input — confirm.
            stride=self._workers * self._input.stride(),
        )