in optimum/tpu/static_cache_xla.py [0:0]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Returns the sequence length of the cached states that were seen by the model.

    A sequence slot counts as occupied when any value along the last dimension of the
    key cache is non-zero. To keep this cheap, only batch member 0 and head 0 of the
    selected layer are inspected.

    Args:
        layer_idx: index into ``self.key_cache`` of the layer to inspect (defaults to 0).

    Returns:
        The number of occupied sequence positions. Note that this is the result of
        ``Tensor.sum()``, i.e. a 0-d tensor living on the cache's device, not a Python
        ``int`` — converting it would force a device sync.
    """
    # TODO: deprecate this function in favor of `cache_position`
    layer_keys = self.key_cache[layer_idx]

    # One scalar index tensor, reused to pick element 0 along both the batch (dim 0) and
    # head (dim 1) axes. index_select is used instead of plain indexing (layer_keys[0, 0])
    # because it is more general and uses less memory on XLA.
    # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
    first = torch.tensor(0, device=layer_keys.device)
    first_head_keys = layer_keys.index_select(0, first).index_select(1, first)

    # Mark each sequence position holding at least one non-zero value, then count them.
    occupied = first_head_keys.any(dim=-1)
    return occupied.sum()