// crates/ratchet-models/src/whisper/residual_block.rs
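//! Residual attention block used by the Whisper encoder and decoder.
//!
//! Each block applies pre-norm self-attention, optional cross-attention over
//! the encoder output (decoder blocks only), and a feed-forward MLP, with a
//! residual connection around each sub-layer.
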
use super::{mha::*, mlp::MLP};
use ratchet::{Device, Tensor};
use ratchet_loader::gguf::gguf::Header;
use ratchet_nn::{KVEntry, LayerNorm, Linear, Module};
use std::io::{BufRead, Seek};
#[cfg(target_arch = "wasm32")]
use {crate::ratchet_from_gguf_web, crate::TensorMap};
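
/// A single Whisper transformer block: pre-norm self-attention, optional
/// cross-attention (decoder blocks only), and a feed-forward MLP, each wrapped
/// in a residual connection.
///
/// A minimal usage sketch (assumes `x`, `xa`, `mask`, and `kv` have been built
/// elsewhere):
///
/// ```ignore
/// let out = block.schedule(ResidualAttentionBlockInputs::new(x, Some(xa), Some(mask), Some(kv)))?;
/// ```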
#[derive(Debug)]
pub struct ResidualAttentionBlock {
attn_ln: LayerNorm,
attn: MultiHeadAttention,
x_attn_ln: Option<LayerNorm>,
x_attn: Option<MultiHeadAttention>,
mlp_ln: LayerNorm,
pub mlp: MLP,
}
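
/// Inputs to a [`ResidualAttentionBlock`]'s `schedule` call: `x` is the block
/// input, `xa` the optional encoder output used for cross-attention, `mask`
/// the optional attention mask, and `cache` the optional KV cache entry for
/// the self-attention layer.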
#[derive(Debug, derive_new::new)]
pub struct ResidualAttentionBlockInputs {
pub x: Tensor,
pub xa: Option<Tensor>,
pub mask: Option<Tensor>,
pub cache: Option<KVEntry>,
}
impl Module for ResidualAttentionBlock {
type Input = ResidualAttentionBlockInputs;
fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
let ResidualAttentionBlockInputs { x, xa, mask, cache } = input;
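        // Pre-norm self-attention with a residual connection.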
let attn_ln = self.attn_ln.schedule(x.clone())?;
let self_attn =
self.attn
.schedule(MHAInputs::new(attn_ln, None, mask.clone(), cache, true))?;
let mut attn = x.add(self_attn)?;
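        // Cross-attention over the encoder output (decoder blocks only), with its own residual.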
if let Some(ref xa_blck) = self.x_attn {
if let Some(xa_ln) = &self.x_attn_ln {
let x_attn_ln = xa_ln.schedule(attn.clone())?;
let x_attn =
xa_blck.schedule(MHAInputs::new(x_attn_ln, xa.clone(), None, None, false))?;
attn = x_attn.add(attn.clone())?;
}
}
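        // Pre-norm feed-forward MLP with a residual connection.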
let mlp_ln = self.mlp_ln.schedule(attn.clone())?;
let mlp = self.mlp.schedule(mlp_ln)?;
mlp.add(attn)
}
}
impl ResidualAttentionBlock {
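    /// Loads a residual block's weights from a GGUF file through `reader`.
    /// `prefix` selects the sub-model (e.g. `"encoder"` or `"decoder"`); only
    /// decoder blocks also load cross-attention weights.
    ///
    /// A minimal sketch (assumes a parsed GGUF `header`, an open `reader`, a
    /// `device`, and a hypothetical head count of 8):
    ///
    /// ```ignore
    /// let block = ResidualAttentionBlock::load(&header, &mut reader, 0, 8, "decoder", &device)?;
    /// ```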
pub fn load<R: BufRead + Seek>(
header: &Header,
reader: &mut R,
layer_index: usize,
n_heads: usize,
prefix: &str,
device: &Device,
) -> anyhow::Result<Self> {
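        // Resolve a layer-relative tensor name to its fully qualified GGUF key
        // and load it from the file.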
let lt = |name: &str| {
let key = format!("model.{}.layers.{}.{}", prefix, layer_index, name);
header.tensor(reader, &key, device)
};
Self::load_inner(lt, prefix, n_heads)
}
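
    /// WASM counterpart of [`Self::load`]: pulls tensors out of an in-memory
    /// `TensorMap` (removing them as it goes) instead of reading from a file.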
#[cfg(target_arch = "wasm32")]
pub fn from_web(
header: &Header,
tensor_map: &mut TensorMap,
layer_index: usize,
n_heads: usize,
prefix: &str,
device: &Device,
) -> anyhow::Result<Self> {
let lt = |name: &str| {
let key = format!("model.{}.layers.{}.{}", prefix, layer_index, name);
let tensor = tensor_map
.remove(&key)
                .ok_or_else(|| anyhow::anyhow!("missing tensor: {}", key))?;
ratchet_from_gguf_web(tensor, device)
};
Self::load_inner(lt, prefix, n_heads)
}
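
    /// Shared construction logic. `lt` resolves a layer-relative tensor name
    /// (e.g. `"self_attn.q_proj.weight"`) to a loaded [`Tensor`], whichever
    /// source it came from.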
fn load_inner<F>(mut lt: F, prefix: &str, n_heads: usize) -> anyhow::Result<Self>
where
F: FnMut(&str) -> anyhow::Result<Tensor>,
{
let attn_ln = LayerNorm::new(
lt("self_attn_layer_norm.weight")?,
Some(lt("self_attn_layer_norm.bias")?),
1e-5,
);
        // Full key example: model.encoder.layers.0.self_attn.v_proj.weight
let attn = MultiHeadAttention::new(
Linear::new(
lt("self_attn.q_proj.weight")?,
Some(lt("self_attn.q_proj.bias")?),
),
Linear::new(lt("self_attn.k_proj.weight")?, None),
Linear::new(
lt("self_attn.v_proj.weight")?,
Some(lt("self_attn.v_proj.bias")?),
),
Linear::new(
lt("self_attn.out_proj.weight")?,
Some(lt("self_attn.out_proj.bias")?),
),
n_heads,
);
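        // Cross-attention and its layer norm exist only in decoder blocks;
        // encoder blocks use self-attention alone.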
let (x_attn_ln, x_attn) = if prefix == "decoder" {
let x_attn_ln = LayerNorm::new(
lt("encoder_attn_layer_norm.weight")?,
Some(lt("encoder_attn_layer_norm.bias")?),
1e-5,
);
let x_attn = MultiHeadAttention::new(
Linear::new(
lt("encoder_attn.q_proj.weight")?,
Some(lt("encoder_attn.q_proj.bias")?),
),
Linear::new(lt("encoder_attn.k_proj.weight")?, None),
Linear::new(
lt("encoder_attn.v_proj.weight")?,
Some(lt("encoder_attn.v_proj.bias")?),
),
Linear::new(
lt("encoder_attn.out_proj.weight")?,
Some(lt("encoder_attn.out_proj.bias")?),
),
n_heads,
);
(Some(x_attn_ln), Some(x_attn))
} else {
(None, None)
};
let mlp_ln = LayerNorm::new(
lt("final_layer_norm.weight")?,
Some(lt("final_layer_norm.bias")?),
1e-5,
);
let mlp = MLP::new(
Linear::new(lt("fc1.weight")?, Some(lt("fc1.bias")?)),
Linear::new(lt("fc2.weight")?, Some(lt("fc2.bias")?)),
);
Ok(Self {
attn_ln,
attn,
x_attn_ln,
x_attn,
mlp_ln,
mlp,
})
}
}