in bindings/python/src/lib.rs [434:536]
/// Opens `filename`, parses the safetensors header and prepares the storage
/// that tensors will later be read from.
///
/// `device` falls back to `Device::Cpu` when `None`. Any other device is only
/// accepted together with `Framework::Pytorch`.
///
/// The storage is chosen as follows:
/// - PyTorch >= 1.11.0: a torch storage created with `from_file` on the path
///   (`UntypedStorage` from 2.0.0 on, `ByteStorage` before), converted through
///   its `untyped` / `_untyped` method.
/// - Older PyTorch and every other framework: the memory map of the file.
///
/// # Errors
/// - `PyFileNotFoundError` when the file cannot be opened.
/// - `SafetensorError` for an unsupported device/framework combination, a
///   header that fails to deserialize, a torch version string that fails to
///   parse, or (torch >= 1.11.0 only) a path that is not valid UTF-8.
/// - Any error raised while mapping the file or while importing / calling
///   into `torch` or `numpy`.
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
    // Every failure to open is reported as "file not found", whatever the cause.
    let Ok(file) = File::open(&filename) else {
        return Err(PyFileNotFoundError::new_err(format!(
            "No such file or directory: {}",
            filename.display()
        )));
    };

    let device = device.unwrap_or(Device::Cpu);
    if device != Device::Cpu && framework != Framework::Pytorch {
        return Err(SafetensorError::new_err(format!(
            "Device {device} is not supported for framework {framework}",
        )));
    }

    // SAFETY: Mmap is used to prevent allocating in Rust
    // before making a copy within Python.
    // NOTE(review): nothing here prevents the file from being modified while
    // it is mapped — confirm callers accept that.
    let buffer = unsafe { MmapOptions::new().map_copy_read_only(&file) }?;

    let (header_len, metadata) = SafeTensors::read_metadata(&buffer).map_err(|err| {
        SafetensorError::new_err(format!("Error while deserializing header: {err}"))
    })?;
    // Tensor data starts after the header plus 8 extra bytes.
    let offset = header_len + 8;

    // Import the Python module backing the framework and cache it in its
    // module-level cell so later lookups can reuse it.
    Python::with_gil(|py| -> PyResult<()> {
        if matches!(framework, Framework::Pytorch) {
            let torch = PyModule::import(py, intern!(py, "torch"))?;
            TORCH_MODULE.get_or_init_py_attached(py, || torch.into());
        } else {
            let numpy = PyModule::import(py, intern!(py, "numpy"))?;
            NUMPY_MODULE.get_or_init_py_attached(py, || numpy.into());
        }
        Ok(())
    })?;

    let storage = if matches!(framework, Framework::Pytorch) {
        Python::with_gil(|py| -> PyResult<Storage> {
            let torch = get_module(py, &TORCH_MODULE)?;
            let raw_version: String = torch.getattr(intern!(py, "__version__"))?.extract()?;
            let torch_version =
                Version::from_string(&raw_version).map_err(SafetensorError::new_err)?;

            // Untyped storage only exists for versions over 1.11.0
            // Same for torch.asarray which is necessary for zero-copy tensor
            let has_untyped_storage = torch_version >= Version::new(1, 11, 0);
            if !has_untyped_storage {
                return Ok(Storage::Mmap(buffer));
            }

            // Python equivalent:
            // storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
            let path_str = filename.to_str().ok_or_else(|| {
                SafetensorError::new_err(format!(
                    "Path {} is not valid UTF-8",
                    filename.display()
                ))
            })?;
            let py_filename: PyObject = path_str.into_pyobject(py)?.into();
            let py_size: PyObject = buffer.len().into_pyobject(py)?.into();
            let py_shared: PyObject = PyBool::new(py, false).to_owned().into();

            // The size keyword and the storage class were renamed in torch 2.0.0.
            let is_torch_2 = torch_version >= Version::new(2, 0, 0);
            let (size_name, storage_name) = if is_torch_2 {
                (intern!(py, "nbytes"), intern!(py, "UntypedStorage"))
            } else {
                (intern!(py, "size"), intern!(py, "ByteStorage"))
            };
            let kwargs =
                [(intern!(py, "shared"), py_shared), (size_name, py_size)].into_py_dict(py)?;

            let storage_class = torch.getattr(storage_name)?;
            let file_storage =
                storage_class.call_method("from_file", (py_filename,), Some(&kwargs))?;

            // Fall back to `_untyped` when `untyped` cannot be looked up.
            let to_untyped: PyBound<'_, PyAny> = file_storage
                .getattr(intern!(py, "untyped"))
                .or_else(|_| file_storage.getattr(intern!(py, "_untyped")))?;
            let untyped_storage = to_untyped.call0()?.into_pyobject(py)?.into();

            let cell = OnceLock::new();
            cell.get_or_init_py_attached(py, || untyped_storage);
            Ok(Storage::TorchStorage(cell))
        })?
    } else {
        Storage::Mmap(buffer)
    };

    Ok(Self {
        metadata,
        offset,
        framework,
        device,
        storage: Arc::new(storage),
    })
}