in optimum/exporters/executorch/integrations.py
def generate(self, prompt_token_ids, max_new_tokens):
    with torch.no_grad():
        # Run the exported encoder once over the full prompt
        encoder_output = self.exported_encoder.module()(prompt_token_ids)

        # Initialize decoding with the decoder start token (0 for T5)
        decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
        generated_ids = [0]

        # Generate tokens one at a time (greedy decoding)
        for i in range(max_new_tokens - 1):
            # Run the exported decoder for a single step; the third argument
            # is the current step index, passed as the cache position
            logits = self.exported_decoder.module()(
                decoder_input_ids,
                encoder_output,
                torch.tensor([i], dtype=torch.long),
            )

            # Greedily pick the most likely next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            generated_ids.append(next_token)

            # Feed only the new token back in; earlier context is held
            # in the decoder's cache
            decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

            # Stop early on the end-of-sequence token
            if next_token == self.config.eos_token_id:
                break

        return generated_ids
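
For reference, a minimal sketch of driving this loop end to end. The tokenizer calls are standard transformers API; the instance name `exported_model` (an object exposing the generate method above, with exported_encoder/exported_decoder already in place) and the checkpoint name are assumptions for illustration.

from transformers import AutoTokenizer

# Assumed checkpoint; any T5-family model the module was exported from works
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
prompt = "translate English to German: The house is wonderful."
prompt_token_ids = tokenizer(prompt, return_tensors="pt").input_ids

# `exported_model` is assumed to be an instance of the class that defines
# the generate() method shown above
generated_ids = exported_model.generate(prompt_token_ids, max_new_tokens=64)
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(text)

Since the loop selects tokens with torch.argmax, decoding is greedy; swapping in sampling or beam search would only require changing the next-token selection step.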