in candle-transformers/src/models/t5.rs [450:577]
fn forward(
&mut self,
xs: &Tensor,
position_bias: Option<&Tensor>,
key_value_states: Option<&Tensor>,
mask: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
// Performs self-attention (when key_value_states is None) or cross-attention
// over the encoder hidden states provided via key_value_states.
let _enter = self.span.enter();
let kv_input = match key_value_states {
None => xs,
Some(key_value_states) => key_value_states,
};
let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
let kv_len = kv_input.dim(1)?;
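// Linear projections: queries always come from xs, keys/values from kv_input
// (i.e. from the encoder output in the cross-attention case).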
let q = self.q.forward(xs)?;
let k = self.k.forward(kv_input)?;
let v = self.v.forward(kv_input)?;
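// Split into heads: (b_sz, seq_len, inner_dim) -> (b_sz, n_heads, seq_len, d_kv).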
let q = q
.reshape((b_sz, q_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let mut k = k
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
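// Incremental decoding: for self-attention, append the keys/values of the new
// tokens to the cache and attend over the full prefix. Cross-attention
// keys/values are recomputed from key_value_states and not cached here.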
if self.use_cache && key_value_states.is_none() {
let _enter = self.span_cache.enter();
if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
k = Tensor::cat(&[kv_cache_k, &k], 2)?;
v = Tensor::cat(&[kv_cache_v, &v], 2)?;
};
self.kv_cache = Some((k.clone(), v.clone()));
};
let k = k.contiguous()?;
let v = v.contiguous()?;
// TODO: Use flash_attn.
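// Raw attention scores of shape (b_sz, n_heads, q_len, kv_len); note that T5
// does not scale the scores by 1/sqrt(d_kv).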
let scores = {
let _enter = self.span_mm.enter();
q.matmul(&k.t()?)?
};
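// Apply the optional attention mask by setting masked positions to -inf so
// they contribute nothing after the softmax.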
let scores = match mask {
None => scores,
Some(mask) => masked_fill(
&scores,
&mask
.unsqueeze(0)?
.unsqueeze(0)?
.repeat((b_sz, self.n_heads))?,
f32::NEG_INFINITY,
)?,
};
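// Position bias: reuse the bias passed in by the caller when available
// (it is typically computed once and shared across layers); otherwise compute
// it from this layer's relative attention bias embedding, if it has one.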
let (scores, position_bias) = match position_bias {
Some(position_bias) => (
scores.broadcast_add(position_bias)?,
Some(position_bias.clone()),
),
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
// Compute the relative position buckets; this only handles the bidirectional case.
let kv_len = k.dim(2)?;
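// When the KV cache is in use only the new query positions need bias rows;
// they correspond to the tail of the key sequence.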
let (q_start, q_end) = match self.use_cache {
true => ((kv_len - q_len) as u32, kv_len as u32),
false => (0_u32, kv_len as u32),
};
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
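// Bucket the relative positions: half of the buckets are reserved for keys
// that come after the query (i < j), the other half for keys at or before it.
// Within each half, offsets below max_exact get their own bucket and larger
// offsets are binned logarithmically, clamped at the last bucket.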
let relative_position = (q_start..q_end)
.map(|i| {
(0..kv_len as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
j - i + num_buckets
} else {
let b = f32::log(
(j - i) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(
max_exact + num_buckets + b as u32,
self.relative_attention_num_buckets as u32 - 1,
)
}
} else if i - j < max_exact {
i - j
} else {
let b = f32::log(
(i - j) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(max_exact + b as u32, num_buckets - 1)
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
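// Move the bucket indices to the device, look up the learned per-head bias and
// permute it to (1, n_heads, q_len, kv_len) so it broadcasts over the batch.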
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?
.to_dtype(scores.dtype())?;
(scores.broadcast_add(&position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
},
};
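// Softmax over the key dimension turns the biased scores into attention weights.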
let attn_weights = {
let _enter = self.span_sm.enter();
candle_nn::ops::softmax_last_dim(&scores)?
};
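// Weighted sum over the values, then merge the heads back into
// (b_sz, q_len, inner_dim) and apply the output projection.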
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias))
}