in models/language_model.py
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attention_mask: torch.Tensor = None, block_kv_cache: dict = None):
"""
Forward pass of the Transformer block.
Args:
x (Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).
cos (Tensor): Cosine positional embeddings for rotary embedding, shape
matching sequence length and head dimension.
sin (Tensor): Sine positional embeddings for rotary embedding, same shape as cos.
attention_mask (Tensor, optional): Attention mask of shape (batch_size, total_kv_length),
with 1 indicating tokens to attend to and 0 for padding tokens.
block_kv_cache (dict, optional): Key-value cache dict for cached keys and values
during decoding. If None, no cache is used.
Returns:
Tuple[Tensor, dict]: Output tensor after the block (same shape as input),
and the updated key-value cache dictionary.
"""
    # Self-attention sub-layer: pre-norm, attention with rotary embeddings and an
    # optional KV cache, followed by a residual connection.
    res = x
    x = self.norm1(x)
    x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache)
    x = res + x

    # Feed-forward (MLP) sub-layer: pre-norm, MLP, followed by a residual connection.
    res = x
    x = self.norm2(x)
    x = self.mlp(x)
    x = res + x

    return x, block_kv_cache
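
# --- Usage sketch (illustrative, not part of the original file) ---
# Shows how a block with this forward signature is typically driven: precompute
# rotary cos/sin tables, run a prefill pass over the prompt, then decode one
# token at a time while reusing block_kv_cache. `ToyBlock` and `rope_tables`
# below are assumed stand-ins with the same interface, not the repo's actual classes.

import torch
import torch.nn as nn


def rope_tables(seq_len: int, head_dim: int, base: float = 10000.0):
    # Rotary embedding cos/sin tables, each of shape (seq_len, head_dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()


class ToyBlock(nn.Module):
    # Minimal stand-in with the same pre-norm / residual structure; its "attention"
    # only demonstrates how block_kv_cache grows along the sequence dimension.
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, x, cos, sin, attention_mask, block_kv_cache):
        if block_kv_cache is None:
            block_kv_cache = {"key": x, "value": x}
        else:
            block_kv_cache = {"key": torch.cat([block_kv_cache["key"], x], dim=1),
                              "value": torch.cat([block_kv_cache["value"], x], dim=1)}
        return self.proj(x), block_kv_cache

    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
        res = x
        x = self.norm1(x)
        x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache)
        x = res + x
        res = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = res + x
        return x, block_kv_cache


if __name__ == "__main__":
    batch, prompt_len, hidden_dim, head_dim = 2, 5, 64, 64
    block = ToyBlock(hidden_dim)

    # Prefill: process the full prompt with no cache.
    cos, sin = rope_tables(prompt_len, head_dim)
    x = torch.randn(batch, prompt_len, hidden_dim)
    mask = torch.ones(batch, prompt_len, dtype=torch.long)      # 1 = attend, 0 = padding
    out, cache = block(x, cos, sin, attention_mask=mask, block_kv_cache=None)

    # Decode: feed a single new token and reuse the cache; the mask now spans all
    # total_kv_length positions (prompt plus generated tokens).
    cos, sin = rope_tables(prompt_len + 1, head_dim)
    new_tok = torch.randn(batch, 1, hidden_dim)
    mask = torch.ones(batch, prompt_len + 1, dtype=torch.long)
    out, cache = block(new_tok, cos[-1:], sin[-1:], attention_mask=mask, block_kv_cache=cache)
    print(out.shape, cache["key"].shape)                        # (2, 1, 64) and (2, 6, 64)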