in optimum/neuron/models/inference/backend/modules/decoder/decoder_wrapper.py [0:0]
def _forward_with_pad(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
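    """Pad the batch dimension of all inputs up to the compiled batch size,
    run the compiled forward, and slice the padding off the returned logits."""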
    # Pad each input up to the compiled batch size by appending rows at the end
    # of the batch dimension.
    def pad_helper(tensor, pad_type="zeros"):
        VALID_PAD_TYPES = {"zeros", "ones", "repeat_first_batchline"}
        assert pad_type in VALID_PAD_TYPES, f"Found {pad_type=}, but valid pad types are {VALID_PAD_TYPES}"
        if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
            return tensor
        padded_shape = list(tensor.shape)
        padded_shape[0] = self.neuron_config.batch_size
        if pad_type == "repeat_first_batchline":
            # Pad with the first batch line's values instead of zeros to reduce the chance of NaNs.
            padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
        else:
            fill_value = 0 if pad_type == "zeros" else 1
            padded_tensor = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
        # Restore the real batch lines over the padding.
        padded_tensor[: tensor.shape[0]] = tensor
        return padded_tensor
    padded_args = []
    for arg in (input_ids, attention_mask, position_ids):
        padded_args.append(pad_helper(arg, pad_type="repeat_first_batchline"))
    # seq_ids need separate handling: with a compiled batch size of 4, padding
    # seq_ids from [0, 2, 1] to [0, 2, 1, 0] would let the padded input write into
    # the KV cache line of sequence 0, so we pad with the unused ids instead,
    # e.g. [0, 2, 1, 3].
    seq_ids_list = seq_ids.tolist()
    padded_seq_ids = torch.tensor(
        seq_ids_list + [x for x in range(self.neuron_config.max_batch_size) if x not in seq_ids_list],
        dtype=seq_ids.dtype,
    )
    padded_args.append(padded_seq_ids)
    # Pad sampling params by repeating the first batch line.
    padded_sampling_params = pad_helper(sampling_params, pad_type="repeat_first_batchline")
    padded_args.append(padded_sampling_params)
    outputs = self._forward(*padded_args)
    # No index select here: it is already handled elsewhere, so we only slice
    # the padding off the batch dimension.
    logits = outputs
    return logits[: seq_ids.shape[0]]
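
The padding behaviour above can be illustrated in isolation. The sketch below is not part of the module: `pad_batch` and `pad_seq_ids` are hypothetical helpers that mirror `pad_helper` (with `pad_type="repeat_first_batchline"`) and the `seq_ids` handling, assuming 2D torch tensors and a compiled batch size of 4.

import torch

def pad_batch(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Repeat the first batch line to fill up to batch_size, then restore the real rows.
    if tensor.shape[0] == batch_size:
        return tensor
    padded = tensor[0].unsqueeze(0).repeat(batch_size, 1).to(tensor.dtype)
    padded[: tensor.shape[0]] = tensor
    return padded

def pad_seq_ids(seq_ids: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    # Fill the padding slots with the unused cache line ids so the padded rows
    # never write into a live KV cache line.
    used = seq_ids.tolist()
    free = [i for i in range(max_batch_size) if i not in used]
    return torch.tensor(used + free, dtype=seq_ids.dtype)

if __name__ == "__main__":
    input_ids = torch.tensor([[11, 12], [21, 22], [31, 32]])
    seq_ids = torch.tensor([0, 2, 1])
    print(pad_batch(input_ids, batch_size=4))      # last row repeats [11, 12]
    print(pad_seq_ids(seq_ids, max_batch_size=4))  # tensor([0, 2, 1, 3])

Padding `seq_ids` with the unused id 3 rather than repeating 0 keeps the padded row's KV cache writes away from sequence 0's cache line, which is the point of the special case in the method above.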