in bindings/python/src/lib.rs [50:107]
/// Converts a mapping of tensor names to Python descriptor dicts into `PyView`s.
///
/// Every descriptor dict must contain three entries:
/// - `"shape"`: extractable as a `Vec<usize>`,
/// - `"data"`: a Python `bytes` object,
/// - `"dtype"`: one of the dtype names handled by the `match` below.
///
/// # Errors
///
/// Returns a `SafetensorError` when one of the keys is missing, when the dtype
/// name is not recognised, or when a `float4_e2m1fn_x2` tensor has an empty
/// shape or a last dimension that cannot be doubled without overflowing.
/// Failures to extract a value into its Rust type are propagated unchanged.
fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<String, PyView>> {
    let mut tensors = HashMap::with_capacity(tensor_dict.len());
    for (tensor_name, tensor_desc) in tensor_dict {
        let mut shape: Vec<usize> = tensor_desc
            .get_item("shape")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc}")))?
            .extract()?;
        let pydata: PyBound<PyAny> = tensor_desc
            .get_item("data")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `data` in {tensor_desc}")))?;
        // Make sure it's extractable first: the byte slice is only used to
        // record the length, the owned `PyBytes` handle is what gets stored.
        let data: &[u8] = pydata.extract()?;
        let data_len = data.len();
        let data: PyBound<PyBytes> = pydata.extract()?;
        let pydtype = tensor_desc
            .get_item("dtype")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc}")))?;
        let dtype: String = pydtype.extract()?;
        let dtype = match dtype.as_ref() {
            "bool" => Dtype::BOOL,
            "int8" => Dtype::I8,
            "uint8" => Dtype::U8,
            "int16" => Dtype::I16,
            "uint16" => Dtype::U16,
            "int32" => Dtype::I32,
            "uint32" => Dtype::U32,
            "int64" => Dtype::I64,
            "uint64" => Dtype::U64,
            "float16" => Dtype::F16,
            "float32" => Dtype::F32,
            "float64" => Dtype::F64,
            "bfloat16" => Dtype::BF16,
            "float8_e4m3fn" => Dtype::F8_E4M3,
            "float8_e5m2" => Dtype::F8_E5M2,
            "float8_e8m0fnu" => Dtype::F8_E8M0,
            "float4_e2m1fn_x2" => Dtype::F4,
            dtype_str => {
                return Err(SafetensorError::new_err(format!(
                    "dtype {dtype_str} is not covered",
                )));
            }
        };
        if dtype == Dtype::F4 {
            // The last dimension is doubled for this dtype (presumably because
            // each `float4_e2m1fn_x2` element packs two 4-bit values -- TODO
            // confirm). An empty shape has no last dimension, so report an
            // error instead of panicking on `shape[n - 1]`.
            let last = shape.last_mut().ok_or_else(|| {
                SafetensorError::new_err(format!(
                    "dtype float4_e2m1fn_x2 requires a non-empty shape (tensor {tensor_name})",
                ))
            })?;
            // Guard against `usize` overflow rather than panicking or wrapping.
            *last = last.checked_mul(2).ok_or_else(|| {
                SafetensorError::new_err(format!(
                    "shape overflow while doubling the last dimension of tensor {tensor_name}",
                ))
            })?;
        }
        let tensor = PyView {
            shape,
            dtype,
            data,
            data_len,
        };
        tensors.insert(tensor_name, tensor);
    }
    Ok(tensors)
}