in optimum/neuron/models/inference/backend/modules/autobucketing.py [0:0]
from typing import List

import torch


def context_encoder_bk(tensors: List[torch.Tensor], buckets, padding_side: str, pad_token: int):
"""
The Bucket Kernel for Context Encoding Models.
1) tensors: A list of torch tensors after running through the flattener
2) buckets: A torch.tensor of the bucket sizes
3) padding_side: A string specifying padding side, must be "left" or "right"
4) pad_token: An integer representing the pad token id. Typically this is 0.
"""
input_ids = tensors[0]
    # -----Remarks for calculating position_idx-----
    # Count the non-pad tokens per sequence; that count is the active sequence length.
    # The resulting tensor is of shape (batch_size,).
    #
    # NOTE: We derive position_ids from input_ids because
    # position_ids is eliminated from the flattener for context encoding models.
    # ----------------------------------------------
position_idx = (input_ids != pad_token).sum(dim=1)
position_idx = position_idx[:, None] # shape (batch_size, 1)
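    # Example (hypothetical values): with pad_token = 0 and right padding,
    #   input_ids    = [[5, 6, 7, 0, 0],
    #                   [8, 9, 0, 0, 0]]
    #   position_idx = [[3],
    #                   [2]]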
    buckets = buckets[None, :] # shape (1, num_buckets)
    # -----Remarks for choosing the bucket_idx-----
    # 1. (buckets < position_idx) produces a bucket_mask that is 1 where a bucket
    #    is too small for the sequence (invalid) and 0 where it is large enough (valid)
    # 2. We convert the boolean tensor to int because argmin doesn't support
    #    boolean tensors
    # 3. argmin then returns, per sequence, the index of the first 0, i.e. the
    #    smallest valid bucket
    # 4. From those per-sequence buckets we take the largest one, otherwise
    #    we'd be truncating tokens from the longer sequences in the batch
    # 5. DO NOT USE argmax since we monkeypatch it,
    #    causing issues with torch.jit.script
    # ---------------------------------------------
    bucket_mask = (buckets < position_idx).to(torch.int) # shape (batch_size, num_buckets)
bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
# select the chosen bucket after squeezing back to original form
bucket = buckets.squeeze(0)[bucket_idx]
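    # Example (hypothetical values): with buckets = [128, 256, 512] and active
    # sequence lengths [100, 300]:
    #   bucket_mask        = [[0, 0, 0],
    #                         [1, 1, 0]]
    #   argmin(..., dim=1) = [0, 2]  -> bucket_idx = 2, bucket = 512
    # so the whole batch is sliced to the 512 bucket and the 300-token
    # sequence is not truncated.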
new_tensors = []
    # ---------Remarks on handling padding sides-------
    # 1. Slice from the side opposite to the padding, so the real tokens are kept
    # 2. Identify seq_ids tensors by their 1-D shape and don't slice them
    # -------------------------------------------------
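    # Example (hypothetical values): with bucket = 512 and a largest bucket
    # (i.e. fully padded length) of 1024, right padding keeps columns [0:512)
    # (real tokens at the front), while left padding keeps columns [512:1024)
    # (real tokens at the back).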
if padding_side == "right":
for i, tens in enumerate(tensors):
# identifies the seq_ids, which don't need to be sliced
if len(tens.shape) == 1:
new_tensors.append(tens)
        else: # all other tensors are of shape (batch_size, seq_len) so we slice on seq_len
new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=0, end=bucket))
else:
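        # buckets[-1][-1] is the largest bucket, i.e. the fully padded sequence length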
max_idx = buckets[-1][-1]
for i, tens in enumerate(tensors):
# identifies the seq_ids, which don't need to be sliced
if len(tens.shape) == 1:
new_tensors.append(tens)
else:
new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=max_idx - bucket, end=max_idx))
return new_tensors, bucket_idx.to(torch.int)
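

# Illustrative sketch (not part of the original module): it reproduces only the
# bucket-selection arithmetic of context_encoder_bk in eager mode with
# hypothetical buckets and sequence lengths, so the argmin/max trick can be
# checked without the flattener or the aten.slice call. The helper name
# `_bucket_selection_sketch` is made up for this example.
def _bucket_selection_sketch():
    pad_token = 0
    buckets = torch.tensor([128, 256, 512])
    # Two right-padded sequences with active lengths 100 and 300
    input_ids = torch.zeros(2, 512, dtype=torch.long)
    input_ids[0, :100] = 11
    input_ids[1, :300] = 22
    position_idx = (input_ids != pad_token).sum(dim=1)[:, None]  # [[100], [300]]
    bucket_mask = (buckets[None, :] < position_idx).to(torch.int)  # [[0, 0, 0], [1, 1, 0]]
    bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
    # The 300-token sequence needs the 512 bucket, so the whole batch uses index 2
    assert bucket_idx.item() == 2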