in backends/candle/src/models/distilbert.rs [113:203]
fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> {
    let _enter = self.span.enter();
    let device = hidden_states.device();
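    // Fused QKV projection: a single linear layer produces the packed
    // query/key/value tensor of shape [batch, seq_len, 3 * num_heads * head_size].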
    let qkv = self.qkv_linear.forward(hidden_states)?;
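    // Split the packed hidden dimension into (3 * num_heads, head_size) and move the
    // head dimension in front of the sequence dimension:
    // [batch, seq_len, 3 * num_heads, head_size] -> [batch, 3 * num_heads, seq_len, head_size].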
    let mut new_qkv_shape = qkv.dims().to_vec();
    new_qkv_shape.pop();
    new_qkv_shape.push(self.num_attention_heads * 3);
    new_qkv_shape.push(self.attention_head_size);

    let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
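    // Chunk along the head dimension into query, key and value,
    // each of shape [batch, num_heads, seq_len, head_size].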
    let qkv = qkv.chunk(3, 1)?;
    let query_layer = &qkv[0].contiguous()?;
    let key_layer = &qkv[1].contiguous()?;
    let value_layer = &qkv[2];
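    // On CUDA devices with a cuBLASLt handle available, use fused batched matmuls;
    // otherwise fall back to plain candle matmuls below.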
    #[allow(unused_variables)]
    let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
        (device, get_cublas_lt_wrapper())
    {
        #[cfg(feature = "cuda")]
        {
            // cuBLASLt batch matmul implementation requires inputs to be dims3
            let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
            let key_layer = key_layer.flatten(0, 1)?;
            let query_layer = query_layer.flatten(0, 1)?;
            let value_layer = value_layer.flatten(0, 1)?;
            let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;

            // If attention_bias is set, we fuse the add by giving it as the output matrix
            // and setting beta to 1.0
            let beta = match attention_bias.is_some() {
                true => Some(1.0),
                false => None,
            };

            // Batch matrix multiplication
            // Fuse softmax scale and attention_bias add
            let attention_scores = cublaslt.batch_matmul(
                &key_layer,
                &query_layer,
                attention_bias.as_ref(),
                Some(self.softmax_scale as f32),
                beta,
                None,
                None,
            )?;
            let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
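            // Second batched matmul: attention_probs @ V. query_layer is passed as the
            // output buffer so its allocation is reused (it is not needed anymore).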
            let context_layer = cublaslt.batch_matmul(
                &value_layer.t()?.contiguous()?,
                &attention_probs,
                // We save one allocation
                Some(&query_layer),
                None,
                None,
                None,
                None,
            )?;

            // Reshape to dims4
            context_layer.reshape((
                batch_size,
                self.num_attention_heads,
                seq_len,
                self.attention_head_size,
            ))
        }
        #[cfg(not(feature = "cuda"))]
        {
            candle::bail!("`cuda` feature is not enabled")
        }
    } else {
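        // Fallback path: standard scaled dot-product attention with plain matmuls.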
        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let mut attention_scores = (attention_scores * self.softmax_scale)?;

        if let Some(attention_bias) = attention_bias {
            attention_scores = attention_scores.add(attention_bias)?;
        }

        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        attention_probs.matmul(&value_layer.contiguous()?)
    }?;
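    // Merge the attention heads back into a single hidden dimension:
    // [batch, num_heads, seq_len, head_size] -> [batch, seq_len, num_heads * head_size].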
    let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

    let hidden_states = self.dense.forward(&context_layer)?;

    Ok(hidden_states)
}