def forward()

in torchrec/distributed/grouped_position_weighted.py [0:0]


    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
        """Return a copy of ``features`` whose weights are per-position weights.

        For each key, a position sequence is obtained (see below) and used to
        look up ``self.position_weights[key]`` via ``torch.gather``. Keys that
        are not in ``self.max_feature_lengths`` get a slice of
        ``self._dummy_weights`` of the same length as that key's values.

        Args:
            features: input jagged features. If it carries no weights, the
                position sequence is built from its offsets. Otherwise its
                existing weights are cast to ``long`` and used as the
                positions (the row-wise sharding path).

        Returns:
            A new ``KeyedJaggedTensor`` with the same keys, values, lengths,
            offsets, stride and length_per_key as ``features``, and with
            ``weights`` replaced by the gathered / dummy weights concatenated
            in key order.

        Raises:
            ``torch.gather`` fails if a position is out of range for
            ``self.position_weights[key]``.
        """
        length_per_key = features.length_per_key()
        if features.weights_or_none() is None:
            # Presumably yields each value's index within its own bag
            # (restarting at every offset) -- TODO confirm against fbgemm's
            # offsets_range semantics.
            cat_seq = torch.ops.fbgemm.offsets_range(
                features.offsets().long(), torch.numel(features.values())
            )
        else:
            # for row-wise sharding: the incoming weights are reused as the
            # position sequence (assumed precomputed upstream -- confirm).
            cat_seq = features.weights().long()
        # One position segment per key, in the same order as features.keys().
        seqs = torch.split(cat_seq, length_per_key)
        weights_list = []
        for key, seq in zip(features.keys(), seqs):
            if key in self.max_feature_lengths:
                weights_list.append(
                    torch.gather(self.position_weights[key], dim=0, index=seq)
                )
            else:
                # This key has no position weights, so it cannot be looked up
                # in self.max_feature_lengths (doing so raises KeyError).
                # Size the slice by this key's segment so the concatenated
                # weights stay aligned with values.
                # NOTE(review): assumes _dummy_weights is 1-D and at least
                # seq.numel() long -- confirm where it is registered.
                weights_list.append(self._dummy_weights[: seq.numel()])
        weights = torch.cat(weights_list)

        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=weights,
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=length_per_key,
        )