in src/accelerate/hooks.py [0:0]
def init_hook(self, module):
    # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
    if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
        self.tied_params_map = None
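
    # No offload: move all of the module's tensors (parameters and buffers) directly to the execution device.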
    if not self.offload and self.execution_device is not None:
        for name, _ in named_module_tensors(module, recurse=self.place_submodules):
            set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
    elif self.offload:
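        # Offload: remember where each tensor originally lived so it can be restored when the hook is detached.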
        self.original_devices = {
            name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
        }
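        # If no weights map was provided, build one by copying the module's tensors to CPU.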
        if self.weights_map is None:
            self.weights_map = {
                name: param.to("cpu")
                for name, param in named_module_tensors(
                    module, include_buffers=self.offload_buffers, recurse=self.place_submodules
                )
            }
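
        # Replace every persistent tensor with a meta tensor; non-persistent buffers are handled separately below.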
        for name, _ in named_module_tensors(
            module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
        ):
            # When using disk offloading, we cannot rely on `weights_map[name].data_ptr()` as the reference pointer,
            # as we have no guarantee that safetensors' `file.get_tensor()` will always return the same pointer.
            # Since we have no reliable way to track the shared data pointers of tied weights in this case, we use the
            # set `tied_params_names` to add pointers to `tied_params_map` on the fly in the pre_forward call.
            if (
                self.tied_params_map is not None
                and recursive_getattr(module, name).data_ptr() in self.tied_params_map
            ):
                self.tied_params_names.add(name)

            set_module_tensor_to_device(module, name, "meta")
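
        # Buffers that are not offloaded stay on the execution device.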
        if not self.offload_buffers and self.execution_device is not None:
            for name, _ in module.named_buffers(recurse=self.place_submodules):
                set_module_tensor_to_device(
                    module, name, self.execution_device, tied_params_map=self.tied_params_map
                )
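        # Non-persistent buffers were excluded from the meta offload above, so move them to the execution device.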
        elif self.offload_buffers and self.execution_device is not None:
            for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
                set_module_tensor_to_device(
                    module, name, self.execution_device, tied_params_map=self.tied_params_map
                )

    return module
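
For context, a minimal sketch of how this hook typically ends up on a module, assuming the public helpers `AlignDevicesHook` and `add_hook_to_module` from `accelerate.hooks` (the toy module and the "cpu" execution device are illustrative):

import torch
from torch import nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# A toy module; any nn.Module works the same way.
model = nn.Linear(4, 4)

# No offload: init_hook (above) moves the module's tensors to the execution device;
# io_same_device=True places the output back on the device the input came from.
hook = AlignDevicesHook(execution_device="cpu", offload=False, io_same_device=True)
add_hook_to_module(model, hook)

out = model(torch.randn(2, 4))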