def _forward_with_pad()

in optimum/neuron/models/inference/backend/modules/decoder/decoder_wrapper.py [0:0]


    def _forward_with_pad(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        # pad the inputs up to the compiled batch size in the end
        def pad_helper(tensor, pad_type="zeros"):
            VALID_PAD_TYPES = {"zeros", "ones", "repeat_first_batchline"}
            assert pad_type in VALID_PAD_TYPES, f"Found {pad_type=}, but valid pad types are {VALID_PAD_TYPES}"
            if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
                return tensor

            padded_shape = list(tensor.shape)
            padded_shape[0] = self.neuron_config.batch_size
            if pad_type == "repeat_first_batchline":
                # pad with first batch line values instead of zeros, to reduce chances of NaN
                padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
            else:
                fill_value = 0 if pad_type == "zeros" else 1
                padded_tensor = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
            padded_tensor[: tensor.shape[0]] = tensor
            return padded_tensor

        padded_args = []
        for arg in (input_ids, attention_mask, position_ids):
            padded_args.append(pad_helper(arg, pad_type="repeat_first_batchline"))

        # need to handle seq_ids separately, when compiled batch is 4, if we pad seq_ids from [0,2,1] to [0,2,1,
        # 0]. then the kv cache of padded input could be written into the first cache line, so we need to pad as [0,
        # 2, 1, 3] instead

        seq_ids_list = seq_ids.tolist()
        padded_seq_ids = torch.tensor(
            seq_ids_list + [x for x in range(self.neuron_config.max_batch_size) if x not in seq_ids_list],
            dtype=seq_ids.dtype,
        )
        padded_args.append(padded_seq_ids)

        # pad sampling params by repeating first batchline
        padded_sampling_params = pad_helper(sampling_params, pad_type="repeat_first_batchline")
        padded_args.append(padded_sampling_params)

        outputs = self._forward(*padded_args)

        # note that we don't do index select here as it should already be handled, simply sliced out padding here
        logits = outputs
        return logits[: seq_ids.shape[0]]