in torchrec/distributed/model_parallel.py [0:0]
def _init_parameters(self, module: nn.Module) -> None:
    @torch.no_grad()
    def init_parameters(module: nn.Module) -> None:
        # Allocate parameters and buffers if over 'meta' device.
        has_meta_param = False
        for name, param in module._parameters.items():
            if isinstance(param, torch.Tensor) and param.device.type == "meta":
                module._parameters[name] = nn.Parameter(
                    torch.empty_like(param, device=self.device),
                    requires_grad=param.requires_grad,
                )
                has_meta_param = True
        for name, buffer in module._buffers.items():
            if isinstance(buffer, torch.Tensor) and buffer.device.type == "meta":
                module._buffers[name] = torch.empty_like(buffer, device=self.device)

        # Init parameters if at least one parameter is over 'meta' device.
        if has_meta_param and hasattr(module, "reset_parameters"):
            # pyre-ignore [29]
            module.reset_parameters()

    module.apply(init_parameters)
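
The method above walks every submodule (via module.apply) and, for any parameter or buffer still on the "meta" device, re-allocates real storage on the target device and re-runs that submodule's reset_parameters. Below is a minimal, self-contained sketch of the same materialization pattern outside of torchrec; the module class SmallDense and the variable target_device are illustrative names introduced here, not part of the library.

import torch
import torch.nn as nn


class SmallDense(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Author parameters/buffers on 'meta': shapes only, no real storage.
        self.weight = nn.Parameter(torch.empty(4, 4, device="meta"))
        self.register_buffer("scale", torch.empty(4, device="meta"))

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.weight)
        self.scale.fill_(1.0)


target_device = torch.device("cpu")
module = SmallDense()

with torch.no_grad():
    has_meta_param = False
    for name, param in module._parameters.items():
        if isinstance(param, torch.Tensor) and param.device.type == "meta":
            # Re-allocate the parameter with real storage on the target device.
            module._parameters[name] = nn.Parameter(
                torch.empty_like(param, device=target_device),
                requires_grad=param.requires_grad,
            )
            has_meta_param = True
    for name, buffer in module._buffers.items():
        if isinstance(buffer, torch.Tensor) and buffer.device.type == "meta":
            module._buffers[name] = torch.empty_like(buffer, device=target_device)
    # Only re-initialize when at least one parameter was actually materialized.
    if has_meta_param and hasattr(module, "reset_parameters"):
        module.reset_parameters()

assert module.weight.device.type == "cpu"
assert module.scale.device.type == "cpu"

The has_meta_param guard matters: reset_parameters is only called on submodules whose parameters were just materialized, so modules that were already constructed on a real device keep whatever state (e.g. loaded weights) they already hold.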