in bindings/python/src/lib.rs [1209:1252]
/// Resolves a safetensors `Dtype` to the corresponding dtype object exposed by `module`.
///
/// Most dtypes are looked up as a plain attribute of `module` (e.g. `module.float32`,
/// `module.float8_e4m3fn`). When `is_numpy` is true, two dtypes are special-cased:
/// - `BF16` is resolved by calling `module.dtype("bfloat16")` instead of reading a
///   `bfloat16` attribute.
/// - `BOOL` is resolved to Python's builtin `bool` instead of `module.bool`.
///
/// # Errors
///
/// Propagates the Python error raised by a failing attribute lookup or call (for example
/// when `module` has no attribute named after the requested dtype), and returns a
/// `SafetensorError` for any `Dtype` variant not mapped below.
fn get_pydtype(module: &PyBound<'_, PyModule>, dtype: Dtype, is_numpy: bool) -> PyResult<PyObject> {
    // Holding a `Bound` reference already proves the GIL is held, so reuse its token
    // instead of re-entering `Python::with_gil`.
    let py = module.py();
    let dtype: PyObject = match dtype {
        Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
        Dtype::F32 => module.getattr(intern!(py, "float32"))?.into(),
        Dtype::BF16 => {
            if is_numpy {
                // NOTE(review): numpy has no built-in bfloat16; this call only succeeds if
                // a bfloat16 dtype has been registered with numpy — confirm callers expect
                // the resulting Python error otherwise.
                module
                    .getattr(intern!(py, "dtype"))?
                    .call1(("bfloat16",))?
                    .into()
            } else {
                module.getattr(intern!(py, "bfloat16"))?.into()
            }
        }
        Dtype::F16 => module.getattr(intern!(py, "float16"))?.into(),
        Dtype::U64 => module.getattr(intern!(py, "uint64"))?.into(),
        Dtype::I64 => module.getattr(intern!(py, "int64"))?.into(),
        Dtype::U32 => module.getattr(intern!(py, "uint32"))?.into(),
        Dtype::I32 => module.getattr(intern!(py, "int32"))?.into(),
        Dtype::U16 => module.getattr(intern!(py, "uint16"))?.into(),
        Dtype::I16 => module.getattr(intern!(py, "int16"))?.into(),
        Dtype::U8 => module.getattr(intern!(py, "uint8"))?.into(),
        Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
        Dtype::BOOL => {
            if is_numpy {
                // Presumably uses the builtin `bool` because `numpy.bool` is not present in
                // every numpy release; numpy accepts the builtin type as a dtype.
                py.import("builtins")?.getattr(intern!(py, "bool"))?.into()
            } else {
                module.getattr(intern!(py, "bool"))?.into()
            }
        }
        Dtype::F8_E4M3 => module.getattr(intern!(py, "float8_e4m3fn"))?.into(),
        Dtype::F8_E5M2 => module.getattr(intern!(py, "float8_e5m2"))?.into(),
        Dtype::F8_E8M0 => module.getattr(intern!(py, "float8_e8m0fnu"))?.into(),
        Dtype::F4 => module.getattr(intern!(py, "float4_e2m1fn_x2"))?.into(),
        other => {
            return Err(SafetensorError::new_err(format!(
                "Dtype not understood: {other}"
            )))
        }
    };
    Ok(dtype)
}