tzrec/utils/state_dict_util.py (35 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from torchrec.modules.mc_modules import MCHManagedCollisionModule

# Fixed length of the padded `_output_segments_tensor` buffer rebuilt below;
# slots past the real segment boundaries are filled with -1.
# NOTE(review): 1025 matches the original hard-coded size — presumably the
# buffer length torchrec's MCHManagedCollisionModule expects; confirm against
# the installed torchrec version before changing.
_OUTPUT_SEGMENTS_TENSOR_LEN = 1025


def fix_mch_state(model: nn.Module) -> None:
    """Fix output_segments_tensor of mc modules may be a meta tensor.

    Scans all submodules of ``model`` and, for every
    ``MCHManagedCollisionModule`` whose ``_output_segments_tensor`` buffer is
    still on the meta device, rebuilds it from the module's
    ``_output_global_offset`` and ``_zch_size``, placing it on the same device
    as ``_current_iter_tensor``.

    Args:
        model (nn.Module): model to fix in place.
    """
    for _, m in model.named_modules():
        # fix output_segments_tensor is a meta tensor.
        if (
            isinstance(m, MCHManagedCollisionModule)
            # pyre-ignore [16]
            and m._buffers["_output_segments_tensor"].is_meta
        ):
            # The module's output range is [offset, offset + zch_size); the
            # remainder of the fixed-size buffer is padded with -1.
            output_segments = [
                m._output_global_offset,
                m._output_global_offset + m._zch_size,
            ]
            m._buffers["_output_segments_tensor"] = torch.tensor(
                output_segments
                + [-1] * (_OUTPUT_SEGMENTS_TENSOR_LEN - len(output_segments)),
                dtype=torch.int64,
                device=m._current_iter_tensor.device,
            )


def init_parameters(module: nn.Module, device: torch.device) -> None:
    """Init param for model with meta device type.

    Walks every submodule (via ``nn.Module.apply``), materializes each
    parameter and buffer that still lives on the ``meta`` device onto
    ``device``, and calls ``reset_parameters()`` on any module that had at
    least one meta parameter so the freshly allocated storage is initialized.

    Args:
        module (nn.Module): model (possibly partially) built on the meta device;
            modified in place.
        device (torch.device): target device for the materialized tensors.
    """

    # Renamed from `init_parameters` so the helper no longer shadows the
    # enclosing public function's name.
    @torch.no_grad()
    def _materialize_meta_tensors(m: nn.Module) -> None:
        # Allocate parameters and buffers if over 'meta' device.
        has_meta_param = False
        for name, param in m._parameters.items():
            if isinstance(param, torch.Tensor) and param.device.type == "meta":
                m._parameters[name] = nn.Parameter(
                    torch.empty_like(param, device=device),
                    requires_grad=param.requires_grad,
                )
                has_meta_param = True
        for name, buf in m._buffers.items():
            if isinstance(buf, torch.Tensor) and buf.device.type == "meta":
                # Buffers are zero-filled; only parameters are re-initialized
                # via reset_parameters() below.
                m._buffers[name] = torch.zeros_like(buf, device=device)

        # Init parameters if at least one parameter is over 'meta' device.
        if has_meta_param and hasattr(m, "reset_parameters"):
            m.reset_parameters()

    module.apply(_materialize_meta_tensors)