in picotron/checkpoint.py [0:0]
def init_model_with_dematerialized_weights(include_buffers: bool = False):
    """Register new module parameters (and optionally buffers) on the ``meta`` device.

    Adapted from the Accelerate library:
    https://github.com/huggingface/accelerate/blob/v0.11.0/src/accelerate/big_modeling.py#L254

    While this generator is suspended at its ``yield``,
    ``nn.Module.register_parameter`` (and ``nn.Module.register_buffer`` when
    ``include_buffers`` is true) are replaced on the ``nn.Module`` class by
    wrappers that perform the normal registration and then swap the stored
    tensor for a copy on ``torch.device("meta")``. The original methods are
    put back in a ``finally`` clause, so they are restored even if model
    construction raises.

    NOTE(review): the body is a generator written to be used as a context
    manager, but no ``contextlib.contextmanager`` decorator is visible in this
    view -- confirm it is applied where the function is defined.

    Args:
        include_buffers (bool): Whether buffers registered in the meantime are
            moved to the meta device as well. Defaults to ``False``.

    Yields:
        None. Control is handed to the caller while the patches are active.
    """
    meta_device = torch.device("meta")
    original_register_parameter = nn.Module.register_parameter
    if include_buffers:
        original_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Let PyTorch validate and store the parameter first, then replace it.
        original_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Copy the instance dict so extra attributes of custom parameter
        # classes are forwarded without mutating the original parameter.
        kwargs = dict(registered.__dict__)
        # requires_grad is not stored in __dict__; without forwarding it a
        # frozen parameter would be rebuilt with the default requires_grad=True.
        kwargs["requires_grad"] = registered.requires_grad
        module._parameters[name] = type(registered)(registered.to(meta_device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        # Forward `persistent` so register_buffer(..., persistent=False) calls
        # keep working while the patch is active.
        original_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(meta_device)

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        yield
    finally:
        # Always undo the class-level patches, including on error.
        nn.Module.register_parameter = original_register_parameter
        if include_buffers:
            nn.Module.register_buffer = original_register_buffer