def prepare_inputs_for_generation()

in jat/modeling_jat.py [0:0]


    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs):
        """Assemble the model keyword arguments for a single generation step.

        Args:
            input_ids: Token ids indexed as ``input_ids[:, position]``. Presumably
                (batch, seq_len); TODO confirm against callers.
            pixel_values: Optional image inputs, forwarded unchanged unless a cache
                is supplied.
            past_key_values: Optional cache from earlier steps. When it is not
                ``None``, only the final position of ``input_ids`` is kept and
                ``pixel_values`` is replaced with ``None``.
            **kwargs: Only ``use_cache`` is read (``None`` if absent); every other
                entry is ignored.

        Returns:
            A dict with the keys ``input_ids``, ``pixel_values``,
            ``past_key_values`` and ``use_cache``, in that order.
        """
        has_cache = past_key_values is not None
        if has_cache:
            # NOTE(review): presumably earlier positions (and any image inputs) are
            # already represented in the cache, so only the newest token is fed in.
            # unsqueeze(-1) restores a trailing length-1 axis after indexing.
            pixel_values = None
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            use_cache=kwargs.get("use_cache"),
        )