in torchrec/distributed/dist_data.py [0:0]
def _wait_impl(self) -> KeyedJaggedTensor:
    """Wait for the pending exchange and assemble the output KeyedJaggedTensor.

    With a single worker there is nothing to exchange: ``self._input`` is
    synced and returned unchanged. Otherwise the values awaitable (and the
    weights awaitable, when one exists) is waited on. The received
    ``lengths`` / ``values`` / ``weights`` are permuted by ``self._recat``
    when it is non-empty, and a new ``KeyedJaggedTensor`` is built from them.

    Returns:
        KeyedJaggedTensor: ``self._input`` when ``self._workers == 1``;
        otherwise a tensor keyed by ``self._keys`` with stride
        ``self._workers * self._input.stride()``.
    """
    if self._workers == 1:
        self._input.sync()
        return self._input

    # NOTE(review): presumably these awaitables wrap the collectives that
    # fill self._values / self._weights; confirm against the constructor.
    self._values_awaitable.wait()
    # Weights are optional. Compare against None explicitly instead of
    # relying on the truthiness of the awaitable object.
    if self._weights_awaitable is not None:
        self._weights_awaitable.wait()

    keys = self._keys
    lengths = self._lengths
    values = self._values
    weights = self._weights

    with record_function("## all2all_data:recat_values ##"):
        # Skip the permutation entirely when the recat tensor is empty.
        if self._recat.numel():
            # lengths is viewed with (workers * splits[rank]) rows so that
            # permute_sparse_data can reorder it row-wise, together with the
            # matching values/weights segments. Presumably there is one row
            # per (source rank, local feature) pair; TODO confirm.
            lengths, values, weights = torch.ops.fbgemm.permute_sparse_data(
                self._recat,
                lengths.view(self._workers * self._splits[self._pg.rank()], -1),
                values,
                weights,
                values.numel(),
            )
            # Flatten the permuted 2-D lengths back to 1-D.
            lengths = lengths.view(-1)

    # NOTE(review): this stride assumes every worker contributed the same batch
    # size as the local input; confirm if variable batch sizes are possible.
    return KeyedJaggedTensor.from_lengths_sync(
        keys=keys,
        values=values,
        weights=weights,
        lengths=lengths,
        stride=self._workers * self._input.stride(),
    )