in backends/candle/src/models/jina_code.rs [91:207]
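// Multi-head self-attention forward pass for the Jina Code model: a fused QKV
// projection, layer norm applied to the query and key projections, then either a
// fused cuBLASLt attention path (CUDA) or a plain matmul/softmax fallback,
// followed by the output projection and a residual layer norm.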
fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let device = hidden_states.device();
let residual = hidden_states.clone();
let qkv = self.qkv_linear.forward(hidden_states)?;
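// qkv holds the fused Q/K/V projection, (batch_size, seq_len, 3 * hidden_size);
// reshape it to (batch_size, seq_len, 3 * num_attention_heads, attention_head_size)
// so the three projections can be split apart below.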
let mut new_qkv_shape = qkv.dims().to_vec();
new_qkv_shape.pop();
new_qkv_shape.push(self.num_attention_heads * 3);
new_qkv_shape.push(self.attention_head_size);
let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
// Split into q, k, v (each chunk keeps the per-head layout)
let qkv = qkv.chunk(3, 2)?;
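// Each chunk is (batch_size, seq_len, num_attention_heads, attention_head_size)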
// Flatten the last two dims again so q and k can go through the layer norm
let query_layer = &qkv[0].flatten_from(D::Minus2)?;
let key_layer = &qkv[1].flatten_from(D::Minus2)?;
// Layer norm on q and k
let query_layer = self.layer_norm_q.forward(query_layer, None)?;
let key_layer = self.layer_norm_k.forward(key_layer, None)?;
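// Rebuild the per-head shape for q and k after the layer norm, then move the
// head dim in front of the sequence dim:
// (batch_size, num_attention_heads, seq_len, attention_head_size)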
let mut new_qk_shape = query_layer.dims().to_vec();
new_qk_shape.pop();
new_qk_shape.push(self.num_attention_heads);
new_qk_shape.push(self.attention_head_size);
let query_layer = query_layer
.reshape(new_qk_shape.as_slice())?
.transpose(1, 2)?
.contiguous()?;
let key_layer = key_layer
.reshape(new_qk_shape.as_slice())?
.transpose(1, 2)?
.contiguous()?;
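// v kept its head split from the chunk above, so it only needs the same
// (1, 2) transpose to reach (batch_size, num_attention_heads, seq_len, attention_head_size)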
let value_layer = &qkv[2].transpose(1, 2)?.contiguous()?;
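// Two attention paths: a fused cuBLASLt implementation when running on CUDA
// with the cuBLASLt wrapper available, and a plain candle matmul/softmax
// fallback otherwise.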
#[allow(unused_variables)]
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
(device, get_cublas_lt_wrapper())
{
#[cfg(feature = "cuda")]
{
// cuBLASLt batch matmul implementation requires inputs to be dims3
let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
let key_layer = key_layer.flatten(0, 1)?;
let query_layer = query_layer.flatten(0, 1)?;
let value_layer = value_layer.flatten(0, 1)?;
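// Flatten the optional bias to dims3 as well; `transpose` turns the
// Option<Result<Tensor>> from `map` into a Result<Option<Tensor>>.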
let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;
// If attention_bias is set, we fuse the add by giving it as the output matrix
// and setting beta to 1.0
let beta = match attention_bias.is_some() {
true => Some(1.0),
false => None,
};
// Batch matrix multiplication
// Fuse softmax scale and attention_bias add
let attention_scores = cublaslt.batch_matmul(
&key_layer,
&query_layer,
attention_bias.as_ref(),
Some(self.softmax_scale as f32),
beta,
None,
None,
)?;
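// attention_scores: (batch_size * num_attention_heads, seq_len, seq_len);
// softmax over the key dimension.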
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let context_layer = cublaslt.batch_matmul(
&value_layer.t()?.contiguous()?,
&attention_probs,
// Reuse `query_layer` as the output buffer to save one allocation
Some(&query_layer),
None,
None,
None,
None,
)?;
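// context_layer: (batch_size * num_attention_heads, seq_len, attention_head_size)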
// Reshape to dims4
context_layer.reshape((
batch_size,
self.num_attention_heads,
seq_len,
self.attention_head_size,
))
}
#[cfg(not(feature = "cuda"))]
{
candle::bail!("`cuda` feature is not enabled")
}
} else {
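// Fallback: standard scaled dot-product attention built from candle ops.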
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let mut attention_scores = (attention_scores * self.softmax_scale)?;
if let Some(attention_bias) = attention_bias {
attention_scores = attention_scores.add(attention_bias)?;
}
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
attention_probs.matmul(&value_layer.contiguous()?)
}?;
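// Merge the heads back together: transpose to
// (batch_size, seq_len, num_attention_heads, attention_head_size) and flatten
// into (batch_size, seq_len, hidden_size).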
let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
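// Output projection; `layer_norm_out` receives the residual and applies it
// together with the normalization.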
let hidden_states = self.dense.forward(&context_layer)?;
let hidden_states = self
.layer_norm_out
.forward(&hidden_states, Some(&residual))?;
Ok(hidden_states)
}
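// A minimal usage sketch (hypothetical setup; `attention` stands for a loaded
// instance of this layer, and the bias construction below is an assumption,
// not the exact code used elsewhere in this crate):
//
//     // hidden_states: (batch_size, seq_len, hidden_size)
//     // attention_bias: broadcastable to (batch_size, num_heads, seq_len, seq_len),
//     // with large negative values at padded key positions so the softmax
//     // drives their weights to zero.
//     let out = attention.forward(&hidden_states, Some(&attention_bias))?;
//     // out: (batch_size, seq_len, hidden_size)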