fn prepare()

in bindings/python/src/lib.rs [50:107]


/// Converts Python-side tensor descriptions into [`PyView`]s keyed by tensor name.
///
/// Every value of `tensor_dict` is a Python dict that must provide three keys:
///
/// * `shape` — extracted as a `Vec<usize>`,
/// * `dtype` — a string naming one of the dtypes listed in the `match` below,
/// * `data` — a Python `bytes` object holding the raw tensor bytes.
///
/// # Errors
///
/// Returns a `SafetensorError` when a required key is missing, when the dtype
/// string is not recognised, or when a `float4_e2m1fn_x2` tensor has an empty
/// shape or a last dimension that overflows `usize` once doubled. Errors raised
/// while extracting Python values are propagated unchanged.
fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<String, PyView>> {
    let mut tensors = HashMap::with_capacity(tensor_dict.len());
    for (tensor_name, tensor_desc) in tensor_dict {
        let mut shape: Vec<usize> = tensor_desc
            .get_item("shape")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc}")))?
            .extract()?;
        let pydata: PyBound<PyAny> = tensor_desc
            .get_item("data")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `data` in {tensor_desc}")))?;
        // Make sure it's extractable first: borrowing the bytes as a slice both
        // validates the object and gives us its length.
        let data: &[u8] = pydata.extract()?;
        let data_len = data.len();
        // Keep a handle to the Python `bytes` object itself for the view.
        let data: PyBound<PyBytes> = pydata.extract()?;
        let pydtype = tensor_desc
            .get_item("dtype")?
            .ok_or_else(|| SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc}")))?;
        let dtype: String = pydtype.extract()?;
        let dtype = match dtype.as_ref() {
            "bool" => Dtype::BOOL,
            "int8" => Dtype::I8,
            "uint8" => Dtype::U8,
            "int16" => Dtype::I16,
            "uint16" => Dtype::U16,
            "int32" => Dtype::I32,
            "uint32" => Dtype::U32,
            "int64" => Dtype::I64,
            "uint64" => Dtype::U64,
            "float16" => Dtype::F16,
            "float32" => Dtype::F32,
            "float64" => Dtype::F64,
            "bfloat16" => Dtype::BF16,
            "float8_e4m3fn" => Dtype::F8_E4M3,
            "float8_e5m2" => Dtype::F8_E5M2,
            "float8_e8m0fnu" => Dtype::F8_E8M0,
            "float4_e2m1fn_x2" => Dtype::F4,
            dtype_str => {
                return Err(SafetensorError::new_err(format!(
                    "dtype {dtype_str} is not covered",
                )));
            }
        };

        if dtype == Dtype::F4 {
            // The last dimension is doubled for this dtype — presumably because
            // each stored element packs two 4-bit values (the `_x2` suffix);
            // confirm against the `Dtype::F4` documentation.
            //
            // An empty shape has no last dimension: report it as an error
            // instead of panicking on `shape[n - 1]`.
            let last_dim = shape.last_mut().ok_or_else(|| {
                SafetensorError::new_err(format!(
                    "dtype float4_e2m1fn_x2 requires at least one dimension, \
                     but tensor `{tensor_name}` has an empty shape",
                ))
            })?;
            // Checked multiplication so an absurd dimension cannot silently
            // wrap around in release builds.
            *last_dim = last_dim.checked_mul(2).ok_or_else(|| {
                SafetensorError::new_err(format!(
                    "last dimension of tensor `{tensor_name}` overflows when doubled \
                     for dtype float4_e2m1fn_x2",
                ))
            })?;
        }

        let tensor = PyView {
            shape,
            dtype,
            data,
            data_len,
        };
        tensors.insert(tensor_name, tensor);
    }
    Ok(tensors)
}