in bindings/python/src/lib.rs [1087:1207]
fn create_tensor<'a>(
framework: &'a Framework,
dtype: Dtype,
shape: &'a [usize],
array: PyObject,
device: &'a Device,
) -> PyResult<PyObject> {
Python::with_gil(|py| -> PyResult<PyObject> {
let (module, is_numpy): (&PyBound<'_, PyModule>, bool) = match framework {
Framework::Pytorch => (
TORCH_MODULE
.get()
.ok_or_else(|| {
SafetensorError::new_err(format!("Could not find module {framework}",))
})?
.bind(py),
false,
),
frame => {
// Attempt to load the frameworks
// Those are needed to prepare the ml dtypes
// like bfloat16
match frame {
Framework::Tensorflow => {
let _ = PyModule::import(py, intern!(py, "tensorflow"));
}
Framework::Flax => {
let _ = PyModule::import(py, intern!(py, "flax"));
}
_ => {}
};
(
NUMPY_MODULE
.get()
.ok_or_else(|| {
SafetensorError::new_err(format!("Could not find module {framework}",))
})?
.bind(py),
true,
)
}
};
let dtype: PyObject = get_pydtype(module, dtype, is_numpy)?;
let count: usize = shape.iter().product();
let shape = shape.to_vec();
let tensor = if count == 0 {
// Torch==1.10 does not allow frombuffer on empty buffers so we create
// the tensor manually.
// let zeros = module.getattr(intern!(py, "zeros"))?;
let shape: PyObject = shape.clone().into_pyobject(py)?.into();
let args = (shape,);
let kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py)?;
module.call_method("zeros", args, Some(&kwargs))?
} else {
// let frombuffer = module.getattr(intern!(py, "frombuffer"))?;
let kwargs = [
(intern!(py, "buffer"), array),
(intern!(py, "dtype"), dtype),
]
.into_py_dict(py)?;
let mut tensor = module.call_method("frombuffer", (), Some(&kwargs))?;
let sys = PyModule::import(py, intern!(py, "sys"))?;
let byteorder: String = sys.getattr(intern!(py, "byteorder"))?.extract()?;
if byteorder == "big" {
let inplace_kwargs =
[(intern!(py, "inplace"), PyBool::new(py, false))].into_py_dict(py)?;
tensor = tensor
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
}
tensor
};
let mut tensor: PyBound<'_, PyAny> = tensor.call_method1("reshape", (shape,))?;
let tensor = match framework {
Framework::Flax => {
let module = Python::with_gil(|py| -> PyResult<&Py<PyModule>> {
let module = PyModule::import(py, intern!(py, "jax"))?;
Ok(FLAX_MODULE.get_or_init_py_attached(py, || module.into()))
})?
.bind(py);
module
.getattr(intern!(py, "numpy"))?
.getattr(intern!(py, "array"))?
.call1((tensor,))?
}
Framework::Tensorflow => {
let module = Python::with_gil(|py| -> PyResult<&Py<PyModule>> {
let module = PyModule::import(py, intern!(py, "tensorflow"))?;
Ok(TENSORFLOW_MODULE.get_or_init_py_attached(py, || module.into()))
})?
.bind(py);
module
.getattr(intern!(py, "convert_to_tensor"))?
.call1((tensor,))?
}
Framework::Mlx => {
let module = Python::with_gil(|py| -> PyResult<&Py<PyModule>> {
let module = PyModule::import(py, intern!(py, "mlx"))?;
Ok(MLX_MODULE.get_or_init_py_attached(py, || module.into()))
})?
.bind(py);
module
.getattr(intern!(py, "core"))?
// .getattr(intern!(py, "array"))?
.call_method1("array", (tensor,))?
}
Framework::Pytorch => {
if device != &Device::Cpu {
let device: PyObject = device.clone().into_pyobject(py)?.into();
let kwargs = PyDict::new(py);
tensor = tensor.call_method("to", (device,), Some(&kwargs))?;
}
tensor
}
Framework::Numpy => tensor,
};
// let tensor = tensor.into_py_bound(py);
Ok(tensor.into())
})
}