def context_encoder_bk()

in optimum/neuron/models/inference/backend/modules/autobucketing.py


import torch
from typing import List


def context_encoder_bk(tensors: List[torch.Tensor], buckets, padding_side: str, pad_token: int):
    """
    The Bucket Kernel for Context Encoding Models.

    1) tensors: A list of torch tensors after running through the flattener
    2) buckets: A torch.Tensor of the bucket sizes, assumed sorted in ascending order
    3) padding_side: A string specifying the padding side; must be "left" or "right"
    4) pad_token: An integer representing the pad token id. Typically this is 0.
    """
    input_ids = tensors[0]

    # -----Remarks for calculating position_idx-----
    # Count the number of non-pad tokens per sequence; that count is the
    # active sequence length. The resulting tensor is of shape (batch_size,)
    #
    # NOTE: We derive position_ids from input_ids because
    # position_ids is eliminated from the flattener for context encoding models.
    # ----------------------------------------------
    position_idx = (input_ids != pad_token).sum(dim=1)
    position_idx = position_idx[:, None]  # shape (batch_size, 1)
    buckets = buckets[None, :]  # shape (1, n_buckets)
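    # e.g. (illustrative values): input_ids = [[5, 6, 7, 0, 0]] with pad_token = 0
    # yields position_idx = [[3]]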

    # -----Remarks for choosing the bucket_idx-----
    # 1. (buckets < position_idx) produces a bucket_mask where buckets that are
    # too small for the sequence are 1 and valid buckets are 0
    # 2. We convert the boolean tensor to int because argmin doesn't support
    # boolean tensors
    # 3. For each sequence, argmin returns the index of the first 0, i.e. the
    # smallest valid bucket
    # 4. Across the batch we then take the largest of those indices, otherwise
    # we'd be truncating tokens from longer sequences.
    # 5. DO NOT USE argmax since we monkeypatch it,
    # causing issues with torch.jit.script
    # ---------------------------------------------
    bucket_mask = (buckets < position_idx).to(torch.int)  # shape (batch_size, n_buckets)
    bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
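    # Worked example (illustrative values): with buckets = [[128, 256, 512]] and
    # position_idx = [[100], [300]], bucket_mask = [[0, 0, 0], [1, 1, 0]];
    # argmin picks [0, 2] per sequence and max keeps 2, i.e. the 512 bucket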

    # squeeze buckets back to shape (n_buckets,) and select the chosen bucket length
    bucket = buckets.squeeze(0)[bucket_idx]

    new_tensors = []

    # ---------Remarks on handling padding sides-------
    # 1. Slice away the padded positions: keep the first `bucket` columns for
    # right padding, or the last `bucket` columns for left padding
    # 2. Identify seq_id tensors by their 1-D shape and don't slice them
    # -------------------------------------------------
    if padding_side == "right":
        for i, tens in enumerate(tensors):
            # identifies the seq_ids, which don't need to be sliced
            if len(tens.shape) == 1:
                new_tensors.append(tens)
            else:  # all other tensors are of shape (batch_size,seq_len) so we slice on seq_len
                new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=0, end=bucket))
    else:
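        # max_idx is the largest bucket size; with left padding the active tokens
        # sit at the right edge, so we keep the last `bucket` columns of each 2-D tensor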
        max_idx = buckets[-1][-1]
        for i, tens in enumerate(tensors):
            # identifies the seq_ids, which don't need to be sliced
            if len(tens.shape) == 1:
                new_tensors.append(tens)
            else:
                new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=max_idx - bucket, end=max_idx))

    return new_tensors, bucket_idx.to(torch.int)
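
A minimal usage sketch, assuming the function above is in scope. The inputs are made up for illustration: the tensor values, the [4, 8] bucket sizes, and the [input_ids, attention_mask, seq_ids] ordering are stand-ins for whatever the flattener actually produces, not taken from the library.

import torch

# Two prompts, right-padded with pad_token 0 to the largest bucket (8 tokens)
input_ids = torch.tensor([
    [11, 12, 0, 0, 0, 0, 0, 0],       # 2 active tokens
    [21, 22, 23, 24, 0, 0, 0, 0],     # 4 active tokens
])
attention_mask = (input_ids != 0).to(torch.int)
seq_ids = torch.tensor([0, 1])        # 1-D, so the kernel passes it through unsliced
buckets = torch.tensor([4, 8])

sliced, bucket_idx = context_encoder_bk(
    [input_ids, attention_mask, seq_ids], buckets, padding_side="right", pad_token=0
)
# Both prompts fit in the 4-token bucket, so bucket_idx is tensor(0, dtype=torch.int32)
# and the 2-D tensors are sliced to shape (2, 4); seq_ids is returned unchanged.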