in models/language_model.py
def generate(self, inputs: torch.Tensor, max_new_tokens: int = 20):
    """
    Generate tokens autoregressively from a given input sequence.

    Args:
        inputs (torch.Tensor): Input tensor containing token indices or embeddings.
            Shape: (batch_size, sequence_length) or (sequence_length,) for a single sequence.
        max_new_tokens (int): Number of new tokens to generate after the input sequence.

    Returns:
        torch.Tensor: The generated sequence, including the original inputs and newly generated tokens.
            Shape: (batch_size, sequence_length + max_new_tokens)
    """
    # Add a batch dimension if a single unbatched sequence was passed
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)
    generated_outputs = inputs.clone()

    # Prefill phase: run the full prompt through the model once to
    # populate the KV cache and produce the output at the last position
    prompt_output, kv_cache_list = self.forward(
        generated_outputs,
        attention_mask=None,
        kv_cache=None,
        start_pos=0
    )
    last_output = prompt_output[:, -1, :]
    # Decode phase: generate one token per step, reusing the KV cache so
    # each forward pass only has to process the newest token
    for i in range(max_new_tokens):
        if self.lm_use_tokens:
            # The model outputs logits; pick the most likely token (greedy decoding)
            next_output = torch.argmax(last_output, dim=-1, keepdim=True)
        else:
            # The model outputs embeddings; restore the sequence dimension
            next_output = last_output.unsqueeze(1)
        generated_outputs = torch.cat((generated_outputs, next_output), dim=1)

        # The token being processed is `next_output`; its position is `generated_outputs.size(1) - 1`
        current_token_start_pos = generated_outputs.size(1) - 1

        # The final token has already been appended, so skip the extra forward pass
        if i == max_new_tokens - 1:
            break
        decode_step_output, kv_cache_list = self.forward(
            next_output,
            attention_mask=None,
            kv_cache=kv_cache_list,
            start_pos=current_token_start_pos
        )
        last_output = decode_step_output[:, -1, :]

    return generated_outputs
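
A minimal sketch of how generate might be invoked, assuming a hypothetical LanguageModel wrapper with lm_use_tokens=True and a compatible tokenizer (the model construction, config, and tokenizer names below are illustrative assumptions, not part of this file):

    import torch

    # Hypothetical setup: `LanguageModel`, `config`, and `tokenizer` are assumed here
    model = LanguageModel(config)
    model.eval()

    # Encode a prompt into token indices of shape (1, sequence_length)
    prompt_ids = tokenizer.encode("The capital of France is", return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(prompt_ids, max_new_tokens=10)  # (1, sequence_length + 10)

    print(tokenizer.decode(output_ids[0]))

Because decoding is greedy (torch.argmax over the logits), repeated calls on the same prompt are deterministic; a sampling strategy such as temperature or top-k would replace the argmax step inside the loop.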