in safetensors/src/tensor.rs [1198:1261]
/// Builds a GPT-2-shaped collection of zero-filled `F32` tensors and
/// round-trips it through serialization.
///
/// The tensor list consists of `wte` and `wpe`, thirteen `h.{i}.*` tensors for
/// each `i` in `0..n_heads`, and the final `ln_f.weight` / `ln_f.bias` pair.
/// All tensors are views into one shared zero-filled buffer, each occupying
/// its own consecutive, non-overlapping byte range.
///
/// The collection is serialized, written to `./out_{model_id}.safetensors`,
/// read back, deserialized, and the file is removed. When the `std` feature is
/// enabled the same round trip is repeated through `serialize_to_file`.
///
/// # Panics
///
/// Panics if the total size in bits is not a multiple of 8, if a tensor view
/// cannot be built, or if any serialization, deserialization or filesystem
/// step fails.
fn gpt2_like(n_heads: usize, model_id: &str) {
    let mut tensors_desc = vec![
        ("wte".to_string(), vec![50257, 768]),
        ("wpe".to_string(), vec![1024, 768]),
    ];
    for i in 0..n_heads {
        tensors_desc.push((format!("h.{i}.ln_1.weight"), vec![768]));
        tensors_desc.push((format!("h.{i}.ln_1.bias"), vec![768]));
        tensors_desc.push((format!("h.{i}.attn.bias"), vec![1, 1, 1024, 1024]));
        tensors_desc.push((format!("h.{i}.attn.c_attn.weight"), vec![768, 2304]));
        tensors_desc.push((format!("h.{i}.attn.c_attn.bias"), vec![2304]));
        tensors_desc.push((format!("h.{i}.attn.c_proj.weight"), vec![768, 768]));
        tensors_desc.push((format!("h.{i}.attn.c_proj.bias"), vec![768]));
        tensors_desc.push((format!("h.{i}.ln_2.weight"), vec![768]));
        tensors_desc.push((format!("h.{i}.ln_2.bias"), vec![768]));
        tensors_desc.push((format!("h.{i}.mlp.c_fc.weight"), vec![768, 3072]));
        tensors_desc.push((format!("h.{i}.mlp.c_fc.bias"), vec![3072]));
        tensors_desc.push((format!("h.{i}.mlp.c_proj.weight"), vec![3072, 768]));
        tensors_desc.push((format!("h.{i}.mlp.c_proj.bias"), vec![768]));
    }
    tensors_desc.push(("ln_f.weight".to_string(), vec![768]));
    tensors_desc.push(("ln_f.bias".to_string(), vec![768]));

    let dtype = Dtype::F32;

    // Total payload size: sum of element counts times the per-element width.
    let nbits: usize = tensors_desc
        .iter()
        .map(|(_, shape)| shape.iter().product::<usize>())
        .sum::<usize>()
        * dtype.bitsize();
    if nbits % 8 != 0 {
        panic!("Misaligned slice");
    }
    // Dividing by a non-zero constant cannot fail, so plain division suffices.
    let total_bytes = nbits / 8;
    let all_data = vec![0; total_bytes];

    let mut metadata = HashMap::with_capacity(tensors_desc.len());
    // `offset` is a byte index into `all_data`; it must advance by the byte
    // length of each tensor (not its element count) so that views do not
    // overlap.
    let mut offset = 0;
    for (name, shape) in tensors_desc {
        let n_elements: usize = shape.iter().product();
        let n_bytes = (n_elements * dtype.bitsize()) / 8;
        let buffer = &all_data[offset..offset + n_bytes];
        let tensor = TensorView::new(dtype, shape, buffer).unwrap();
        metadata.insert(name, tensor);
        offset += n_bytes;
    }
    // Every byte of the backing buffer must have been handed to some tensor.
    assert_eq!(offset, all_data.len());

    let filename = format!("./out_{model_id}.safetensors");

    // In-memory serialization, then a manual write/read round trip.
    let out = serialize(&metadata, None).unwrap();
    std::fs::write(&filename, out).unwrap();
    let raw = std::fs::read(&filename).unwrap();
    let _deserialized = SafeTensors::deserialize(&raw).unwrap();
    std::fs::remove_file(&filename).unwrap();

    // File api
    #[cfg(feature = "std")]
    {
        serialize_to_file(&metadata, None, std::path::Path::new(&filename)).unwrap();
        let raw = std::fs::read(&filename).unwrap();
        let _deserialized = SafeTensors::deserialize(&raw).unwrap();
        std::fs::remove_file(&filename).unwrap();
    }
}