in safetensors/src/tensor.rs [1140:1186]
/// Round-trips a single F32 tensor of shape `[1, 2, 3]` through `serialize` and
/// `SafeTensors::deserialize`, then checks that slicing the parsed view yields the
/// expected little-endian byte runs.
fn test_slicing() {
    // Flattens f32 values into their little-endian byte encoding.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    // Six values; the assertions below rely on them being stored row-major,
    // i.e. [[[0, 1, 2], [3, 4, 5]]].
    let data = le_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    let metadata: HashMap<String, TensorView> = HashMap::from([(
        "attn.0".to_string(),
        TensorView {
            dtype: Dtype::F32,
            shape: vec![1, 2, 3],
            data: &data,
        },
    )]);

    let serialized = serialize(&metadata, None).unwrap();
    let parsed = SafeTensors::deserialize(&serialized).unwrap();
    let tensor = parsed.tensor("attn.0").unwrap();

    // Whole first axis, first entry of the second axis: elements 0, 1, 2.
    let first_row: Vec<u8> = tensor
        .slice((.., ..1))
        .unwrap()
        .flat_map(|chunk| chunk.iter().copied())
        .collect();
    assert_eq!(first_row, vec![0u8, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64]);
    assert_eq!(first_row, le_bytes(&[0.0, 1.0, 2.0]));

    // Whole first two axes, first entry of the last axis: elements 0 and 3.
    let first_column: Vec<u8> = tensor
        .slice((.., .., ..1))
        .unwrap()
        .flat_map(|chunk| chunk.iter().copied())
        .collect();
    assert_eq!(first_column, vec![0u8, 0, 0, 0, 0, 0, 64, 64]);
    assert_eq!(first_column, le_bytes(&[0.0, 3.0]));
}