in jat/modeling_jat.py [0:0]
def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs):
    """Build the keyword arguments fed to the model at one generation step.

    When ``past_key_values`` is given (any value other than ``None``, even an
    empty container), only the final position of ``input_ids`` is kept and
    ``pixel_values`` is discarded. The cached state presumably already covers
    the earlier positions — confirm against the model's forward pass.

    Args:
        input_ids: Token ids. They are indexed as ``input_ids[:, -1]`` and
            ``.unsqueeze(-1)`` is called on the result, so a tensor of at least
            two dimensions is expected.
        pixel_values: Optional image inputs. They are forwarded only when no
            cache is supplied.
        past_key_values: Cached state from earlier steps, or ``None``. It is
            passed through unchanged.
        **kwargs: Extra generation arguments. Only ``"use_cache"`` is read;
            its value is ``None`` when the key is absent.

    Returns:
        A dict with the keys ``"input_ids"``, ``"pixel_values"``,
        ``"past_key_values"`` and ``"use_cache"``, in that order.
    """
    if past_key_values is None:
        step_ids, step_pixels = input_ids, pixel_values
    else:
        # With a cache present, feed only the newest token. Keep a trailing
        # length-1 axis so the result stays 2-D.
        newest = input_ids[:, -1]
        step_ids, step_pixels = newest.unsqueeze(-1), None
    return dict(
        input_ids=step_ids,
        pixel_values=step_pixels,
        past_key_values=past_key_values,
        use_cache=kwargs.get("use_cache"),
    )