in torchrec/distributed/grouped_position_weighted.py [0:0]
def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
    if features.weights_or_none() is None:
        # No weights present: derive per-element position indices
        # (0, 1, ..., len-1 within each bag) from the offsets.
        cat_seq = torch.ops.fbgemm.offsets_range(
            features.offsets().long(), torch.numel(features.values())
        )
    else:
        # For row-wise sharding the position indices are carried in the
        # weights field, so reuse them directly.
        cat_seq = features.weights().long()
    # Split the flat position indices back into one segment per key.
    seqs = torch.split(cat_seq, features.length_per_key())
    weights_list = []
    for key, seq in zip(features.keys(), seqs):
        if key in self.max_feature_lengths:
            # Look up the learned position weight for each element's position.
            weights_list.append(
                torch.gather(self.position_weights[key], dim=0, index=seq)
            )
        else:
            # Keys without a configured max length get constant dummy weights;
            # slice to the segment length so the concatenation below stays
            # aligned with features.values().
            weights_list.append(self._dummy_weights[: seq.size(0)])
    weights = torch.cat(weights_list)

    return KeyedJaggedTensor(
        keys=features.keys(),
        values=features.values(),
        weights=weights,
        lengths=features.lengths(),
        offsets=features.offsets(),
        stride=features.stride(),
        length_per_key=features.length_per_key(),
    )
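
A minimal reference for the no-weights branch: torch.ops.fbgemm.offsets_range is used above to turn bag offsets into per-element position indices. The sketch below reproduces the assumed semantics in plain PyTorch (the name offsets_range_reference and its shape assumptions are illustrative, not fbgemm's API); it expects offsets in the cumulative-sum form returned by KeyedJaggedTensor.offsets(), i.e. starting at 0 and ending at the total number of values.

import torch

def offsets_range_reference(offsets: torch.Tensor, total: int) -> torch.Tensor:
    # Assumed semantics: for each bag [offsets[i], offsets[i+1]), emit the
    # in-bag positions 0, 1, ..., len-1 of its elements.
    lengths = torch.diff(offsets)  # per-bag lengths
    starts = offsets[:-1]          # start offset of each bag
    # Global element index minus the bag's start offset = position in the bag.
    return torch.arange(total, device=offsets.device) - torch.repeat_interleave(
        starts, lengths
    )

# e.g. offsets = [0, 2, 5, 6, 8], total = 8  ->  [0, 1, 0, 1, 2, 0, 0, 1]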
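
A hedged usage sketch: the constructor call GroupedPositionWeightedModule(max_feature_lengths=...) and the direct KeyedJaggedTensor construction below are assumptions based on the surrounding torchrec code, not verified against a specific release; treat this as an illustration of how the forward above attaches one learned position weight per value.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.grouped_position_weighted import (
    GroupedPositionWeightedModule,
)

# Two features: "f1" sees at most 3 values per example, "f2" at most 2.
fp = GroupedPositionWeightedModule(max_feature_lengths={"f1": 3, "f2": 2})

# Batch of 2 examples; lengths are laid out key-major: f1 -> [2, 3], f2 -> [1, 2].
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 13, 14, 20, 21, 22]),
    lengths=torch.tensor([2, 3, 1, 2]),
)

weighted = fp(features)
# Same keys, values and lengths as the input, plus one position weight per value.
print(weighted.weights().shape)  # torch.Size([8])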