in models/language_model.py
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, kv_cache: list[dict] = None, start_pos: int = 0):
    """
    Performs a forward pass through the language model.

    Args:
        x (Tensor): Input tensor. If `lm_use_tokens` is True, this should be
            token indices with shape (batch_size, sequence_length).
            If False, it should be embeddings of shape (batch_size, sequence_length, hidden_dim).
        attention_mask (Tensor, optional): Attention mask specifying which tokens
            to attend to, typically of shape (batch_size, sequence_length).
            Default is None.
        kv_cache (list[dict], optional): List of key-value caches, one per transformer
            block, used for efficient autoregressive decoding.
            If None, fresh cache entries are created and returned. Default is None.
        start_pos (int, optional): Starting position index of the current input
            sequence. Used to compute rotary positional embeddings correctly,
            especially for cached sequences during generation. Default is 0.

    Returns:
        Tuple:
            - Tensor: Output logits of shape (batch_size, sequence_length, vocab_size)
              if `lm_use_tokens` is True, otherwise hidden-state embeddings of shape
              (batch_size, sequence_length, hidden_dim).
            - list: Updated list of key-value caches, one per transformer block,
              useful for autoregressive decoding and incremental generation.

    Behavior:
        - If `lm_use_tokens` is True, the input token indices are first embedded.
        - Rotary positional embeddings are generated for the current input positions
          and passed to each transformer block.
        - Each transformer block processes the input together with the rotary
          embeddings, the attention mask, and its optional cached key-values.
        - After all blocks have been processed, a final RMS normalization is applied.
        - If tokens are used, the normalized hidden states are projected to logits
          over the vocabulary.
        - The method returns the logits or embeddings along with the updated
          cache for efficient decoding.
    """
    if self.lm_use_tokens:
        x = self.token_embedding(x)

    # T_curr is the length of the current input sequence
    B, T_curr, _ = x.size()

    # Create position_ids for the current sequence based on start_pos
    current_position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device).unsqueeze(0).expand(B, -1)
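    # current_position_ids has shape (B, T_curr): the absolute positions of the tokens in this call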
    cos, sin = self.rotary_embd(current_position_ids)  # Get rotary position embeddings for current tokens

    # Initialize new KV cache if none provided
    if kv_cache is None:
        kv_cache = [None] * len(self.blocks)
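
    # Run the hidden states through each transformer block; each block also returns its updated KV cache entry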
    for i, block in enumerate(self.blocks):
        x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

    x = self.norm(x)  # Final RMS normalization

    # Compute logits if we are using tokens, otherwise stay in the embedding space
    if self.lm_use_tokens:
        x = self.head(x)

    return x, kv_cache
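
For illustration, the sketch below shows how this `forward` signature supports autoregressive generation with the returned KV cache: a prefill pass over the prompt, then one-token decode steps that reuse the cache and advance `start_pos`. It is not part of the file above; it assumes a hypothetical `model` instance of this language model with `lm_use_tokens=True` and an existing `prompt_ids` tensor of token indices, and greedy argmax sampling, the 20-token budget, and the omitted attention mask are simplifications.

import torch

model.eval()
generated = prompt_ids  # assumed to exist: token indices of shape (1, prompt_len)

with torch.no_grad():
    # Prefill: process the whole prompt at position 0 and collect the per-block KV cache
    logits, kv_cache = model(prompt_ids, kv_cache=None, start_pos=0)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)

    # Decode: feed one token at a time, reusing the cache; start_pos is the number
    # of tokens already processed so the rotary embeddings stay aligned
    for _ in range(20):
        generated = torch.cat([generated, next_token], dim=1)
        logits, kv_cache = model(next_token, kv_cache=kv_cache, start_pos=generated.size(1) - 1)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)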