def split()

in torchrec/sparse/jagged_tensor.py [0:0]


    def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
        """
        Split this KeyedJaggedTensor into consecutive groups of keys.

        ``segments[i]`` is how many keys go into the i-th output. Outputs are
        filled front to back in the existing key order, so output ``i`` holds
        ``self._keys[sum(segments[:i]) : sum(segments[:i + 1])]``.

        NOTE(review): ``segments`` is presumably expected to sum to
        ``len(self._keys)``; nothing here checks that. If the sum is smaller,
        the trailing keys appear in no output.

        Args:
            segments (List[int]): number of keys in each output, in order.

        Returns:
            List[KeyedJaggedTensor]: one KJT per entry of ``segments``.
                * A segment spanning every key passes this KJT's own keys,
                  tensors and cached per-key metadata straight through,
                  without slicing.
                * A zero-key segment gets freshly built empty values, lengths
                  and offsets (plus empty weights when this KJT is weighted).
                * Any other segment slices values/weights by the per-key value
                  offsets and lengths by ``stride`` entries per key; only
                  ``length_per_key`` is carried over, while offsets,
                  offset_per_key, index_per_key and jt_dict are passed as None.
        """
        pieces: List[KeyedJaggedTensor] = []
        num_keys = len(self._keys)
        stride = self._stride
        device = self.device()
        has_weights = self.weights_or_none() is not None
        lengths_by_key = self.length_per_key()
        # Indexed by key position: entry k is the position in the values
        # tensor where key k's values start (used below as slice bounds).
        value_offsets_by_key = self.offset_per_key()

        key_lo = 0
        value_lo = 0
        for count in segments:
            key_hi = key_lo + count
            value_hi = value_offsets_by_key[key_hi]

            if count == num_keys:
                # The segment spans every key: reuse this KJT's tensors and
                # caches directly rather than slicing them.
                piece = KeyedJaggedTensor(
                    keys=self._keys,
                    values=self._values,
                    weights=self.weights_or_none(),
                    lengths=self._lengths,
                    offsets=self._offsets,
                    stride=stride,
                    length_per_key=self._length_per_key,
                    offset_per_key=self._offset_per_key,
                    index_per_key=self._index_per_key,
                    jt_dict=self._jt_dict,
                )
            elif count == 0:
                # No keys in this segment: build empty tensors on the same
                # device, keeping the value/weight dtypes of this KJT.
                piece = KeyedJaggedTensor(
                    keys=self._keys[key_lo:key_hi],
                    values=torch.tensor(
                        [], device=device, dtype=self._values.dtype
                    ),
                    weights=torch.tensor(
                        [], device=device, dtype=self.weights().dtype
                    )
                    if has_weights
                    else None,
                    lengths=torch.tensor([], device=device, dtype=torch.int),
                    offsets=torch.tensor([], device=device, dtype=torch.int),
                    stride=stride,
                    length_per_key=None,
                    offset_per_key=None,
                    index_per_key=None,
                    jt_dict=None,
                )
            else:
                piece = KeyedJaggedTensor(
                    keys=self._keys[key_lo:key_hi],
                    values=self._values[value_lo:value_hi],
                    weights=self.weights()[value_lo:value_hi]
                    if has_weights
                    else None,
                    # lengths hold `stride` entries per key, laid out key by key.
                    lengths=self.lengths()[key_lo * stride : key_hi * stride],
                    offsets=None,
                    stride=stride,
                    length_per_key=lengths_by_key[key_lo:key_hi],
                    offset_per_key=None,
                    index_per_key=None,
                    jt_dict=None,
                )
            pieces.append(piece)

            key_lo = key_hi
            value_lo = value_hi
        return pieces