in torchrec/sparse/jagged_tensor.py [0:0]
def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
    """Split this tensor into consecutive groups of keys.

    Keys are consumed from the front of ``self._keys`` in order: the first
    output receives the first ``segments[0]`` keys, the second output the
    next ``segments[1]`` keys, and so on.

    Args:
        segments: number of keys to place in each output tensor.

    Returns:
        One ``KeyedJaggedTensor`` per entry of ``segments``, in the same
        order. Three cases are handled:

        * a segment equal to the total number of keys is rebuilt from this
          tensor's own fields (values, weights, lengths, offsets and the
          per-key caches) without any slicing;
        * a segment of ``0`` produces a tensor with empty values, lengths
          and offsets (and empty weights when this tensor has weights);
        * any other segment slices values/weights by the per-key offsets
          and lengths by ``stride``-sized key rows.
    """
    # Fetched once up front and indexed inside the loop.
    # NOTE(review): these accessors presumably also fill the cached
    # ``self._length_per_key`` / ``self._offset_per_key`` that the
    # whole-tensor branch forwards below -- confirm against their definitions.
    per_key_lengths = self.length_per_key()
    per_key_offsets = self.offset_per_key()
    total_keys = len(self._keys)

    pieces: List[KeyedJaggedTensor] = []
    key_begin = 0
    value_begin = 0
    for key_count in segments:
        key_end = key_begin + key_count
        # Position in the flat values tensor where this group of keys stops.
        value_end = per_key_offsets[key_end]
        piece_keys: List[str] = self._keys[key_begin:key_end]

        if key_count == total_keys:
            # The segment spans every key: no torch slicing required, so
            # hand over the existing tensors and cached metadata untouched.
            piece = KeyedJaggedTensor(
                keys=self._keys,
                values=self._values,
                weights=self.weights_or_none(),
                lengths=self._lengths,
                offsets=self._offsets,
                stride=self._stride,
                length_per_key=self._length_per_key,
                offset_per_key=self._offset_per_key,
                index_per_key=self._index_per_key,
                jt_dict=self._jt_dict,
            )
        elif key_count == 0:
            # Empty segment: build zero-element tensors on this tensor's
            # device, keeping the value/weight dtypes of the source.
            target_device = self.device()
            empty_values = torch.tensor(
                [], device=target_device, dtype=self._values.dtype
            )
            empty_weights = (
                torch.tensor(
                    [],
                    device=target_device,
                    dtype=self.weights().dtype,
                )
                if self.weights_or_none() is not None
                else None
            )
            empty_lengths = torch.tensor([], device=target_device, dtype=torch.int)
            empty_offsets = torch.tensor([], device=target_device, dtype=torch.int)
            piece = KeyedJaggedTensor(
                keys=piece_keys,
                values=empty_values,
                weights=empty_weights,
                lengths=empty_lengths,
                offsets=empty_offsets,
                stride=self._stride,
                length_per_key=None,
                offset_per_key=None,
                index_per_key=None,
                jt_dict=None,
            )
        else:
            # Partial segment: carve the matching ranges out of each tensor.
            piece_length_per_key = per_key_lengths[key_begin:key_end]
            piece_values = self._values[value_begin:value_end]
            piece_weights = (
                self.weights()[value_begin:value_end]
                if self.weights_or_none() is not None
                else None
            )
            # Lengths are indexed by key rows of ``stride`` entries each.
            piece_lengths = self.lengths()[
                key_begin * self._stride : key_end * self._stride
            ]
            piece = KeyedJaggedTensor(
                keys=piece_keys,
                values=piece_values,
                weights=piece_weights,
                lengths=piece_lengths,
                offsets=None,
                stride=self._stride,
                length_per_key=piece_length_per_key,
                offset_per_key=None,
                index_per_key=None,
                jt_dict=None,
            )

        pieces.append(piece)
        # The next segment starts where this one stopped.
        key_begin = key_end
        value_begin = value_end
    return pieces