// crates/ratchet-nn/src/embedding.rs
use crate::Module;
use ratchet::{shape, Tensor};
/// # Embedding
///
/// Standard `torch.nn.Embedding` module.
///
/// Holds a lookup table whose rows are gathered by integer index in
/// `schedule`; the last axis of `weight` is the embedding dimension.
/// `derive_new` generates the `Embedding::new(weight)` constructor.
#[derive(Debug, derive_new::new)]
pub struct Embedding {
// Embedding table. Rows (axis 0) are selected by index; the size of the
// final axis is appended to the output shape (see `schedule`).
pub weight: Tensor,
}
impl Module for Embedding {
    type Input = Tensor;

    /// Gathers embedding rows for every index in `input`.
    ///
    /// The index tensor is flattened, used to `index_select` rows (axis 0)
    /// out of `weight`, and the result is viewed back to the original index
    /// shape with the embedding dimension appended as a trailing axis.
    fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
        // Output shape = input shape + [embedding_dim] (last axis of weight).
        let embedding_dim = self.weight.shape()[self.weight.rank() - 1];
        let mut result_shape = input.shape().clone();
        result_shape.push(embedding_dim);

        // Flatten indices to 1-D so a single index_select covers every lookup.
        let num_indices = input.shape().numel();
        let flat_indices = input.view(shape![num_indices])?;

        let gathered = self.weight.clone().index_select(flat_indices, 0)?;
        gathered.view(result_shape)
    }
}
#[cfg(all(test, feature = "pyo3"))]
mod tests {
    // NOTE(review): `Api` and `Tokenizer` look unused in this module — confirm
    // nothing (e.g. macro expansion) needs them before removing.
    use hf_hub::api::sync::Api;
    use proptest::arbitrary::Arbitrary;
    use proptest::strategy::{BoxedStrategy, Just, Strategy};
    use ratchet_loader::gguf::gguf::Header;
    use test_strategy::proptest;
    use tokenizers::Tokenizer;

    use ratchet::test_util::run_py_prg;
    use ratchet::{rvec, shape, Device, DeviceRequest, Shape, Tensor};

    use crate::{Embedding, Module};

    /// One randomly generated lookup case: a vocabulary shape
    /// `[vocab_size, embedding_dim]` plus a tensor of indices into it.
    #[derive(Debug, Clone)]
    struct EmbeddingProblem {
        vocab_shape: Shape,
        indices: Tensor,
    }

    impl Arbitrary for EmbeddingProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // Vocabulary shape: 1..512 rows by 1..16 columns; then 1..64
            // indices drawn uniformly from [0, vocab_size).
            {
                let args = vec![1..512usize, 1..16usize];
                args.prop_map(Into::<Shape>::into).boxed()
            }
            .prop_flat_map(|vocab_shape| (Just(vocab_shape), 1..64usize))
            .prop_map(|(vocab_shape, num_indices)| {
                let indices =
                    Tensor::randint(0, vocab_shape[0] as i32, shape![num_indices], Device::CPU);
                EmbeddingProblem {
                    vocab_shape,
                    indices,
                }
            })
            .boxed()
        }
    }

    /// Computes the reference result with `torch.nn.Embedding` through the
    /// embedded Python interpreter.
    fn ground_truth(weight: &Tensor, indices: &Tensor) -> anyhow::Result<Tensor> {
        let arg = "torch.from_numpy(weight)";
        let prg = format!(
            r#"
import torch
def embedding(weight, indices):
    embedding = torch.nn.Embedding.from_pretrained({})
    return embedding(torch.from_numpy(indices)).numpy()
"#,
            arg
        );
        // `prg` is already an owned String — pass it directly rather than
        // re-allocating with `.to_string()` (clippy: redundant_clone).
        run_py_prg(prg, &[weight, indices], &[], weight.dt())
    }

    /// Runs one problem end-to-end on the GPU and compares against PyTorch.
    fn run_embedding_trial(problem: EmbeddingProblem) {
        let device = Device::request_device(DeviceRequest::GPU).unwrap();
        println!("Embedding problem: {:?}", problem);
        let EmbeddingProblem {
            vocab_shape,
            indices,
        } = problem;
        // Generate weights on the CPU so the Python reference sees the exact
        // same values, then move both operands to the GPU for the real run.
        let weight = Tensor::randn::<f32>(vocab_shape, Device::CPU);
        let ground_truth = ground_truth(&weight, &indices).unwrap();
        let weight = weight.to(&device).unwrap();
        let indices = indices.to(&device).unwrap();
        let embedding = Embedding::new(weight);
        let result = embedding.schedule(indices).unwrap().resolve().unwrap();
        let x = result.to(&Device::CPU).unwrap();
        ground_truth.all_close(&x, 1e-6, 1e-6).unwrap();
    }

    /// Deterministic regression case exercising a 2-D index tensor.
    #[test]
    fn debug_embedding() {
        let prob = EmbeddingProblem {
            vocab_shape: shape![10000, 384],
            indices: Tensor::from_data([400i32, 9001i32, 5555i32], shape![1, 3], Device::CPU),
        };
        run_embedding_trial(prob);
    }

    /// Property test over randomly generated vocabularies and index sets.
    #[proptest(cases = 16)]
    fn test_embedding(prob: EmbeddingProblem) {
        run_embedding_trial(prob);
    }
}