def get_seq_length()

in optimum/tpu/static_cache_xla.py [0:0]


    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # The cache is considered occupied at a given sequence position (3rd dim) if that position holds any
        # non-zero value. To save on compute, limit the check to the first batch member and the first head.
        # TODO: deprecate this function in favor of `cache_position`
        key_cache = self.key_cache[layer_idx]
        device = key_cache.device

        # index_select(dim, index) performs the same operation as tensor[..., index, ...],
        # but it is more general and uses less memory on XLA.
        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
        # Select the first batch member: shape (1, num_heads, seq_len, head_dim).
        item = key_cache.index_select(0, torch.tensor(0, device=device))
        # Select the first head of that member: shape (1, 1, seq_len, head_dim).
        head = item.index_select(1, torch.tensor(0, device=device))

        # Count sequence positions holding at least one non-zero value; the result is a 0-dim tensor.
        return head.any(dim=-1).sum()
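
To see the counting trick in isolation, here is a minimal, self-contained sketch (not the optimum-tpu API; the tensor shapes and fill values are assumptions for illustration). It zero-initializes a dummy key cache laid out as (batch, num_heads, max_seq_len, head_dim), writes the first five sequence positions, and recovers that length with the same index_select / any / sum pattern:

    import torch

    # Dummy static key cache with an assumed (batch, num_heads, max_seq_len, head_dim) layout.
    batch, num_heads, max_seq_len, head_dim = 2, 4, 16, 8
    key_cache = torch.zeros(batch, num_heads, max_seq_len, head_dim)

    # Pretend the model has already written 5 tokens into the cache.
    key_cache[:, :, :5, :] = 1.0

    device = key_cache.device
    item = key_cache.index_select(0, torch.tensor(0, device=device))  # first batch member
    head = item.index_select(1, torch.tensor(0, device=device))       # first head
    seq_length = head.any(dim=-1).sum()

    print(seq_length)  # tensor(5)

The index_select calls mirror the implementation above; per the comment in the source, they are preferred over bracket indexing because they are more general and use less memory on XLA.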