maga_transformer/tools/fake_util.py (20 lines of code) (raw):

import torch
from typing import Dict, List


def generate_fake_model(shape_map: Dict[str, List[int]]) -> Dict[str, torch.Tensor]:
    """Build a dictionary of random half-precision tensors.

    Args:
        shape_map: Maps each tensor name to the shape it should have.

    Returns:
        A new dict with one ``torch.rand`` tensor (dtype ``torch.half``) per
        entry of ``shape_map``, keyed by the same names.
    """
    fake_model: Dict[str, torch.Tensor] = {}
    for key, shape in shape_map.items():
        print(f"generate tensor: {key}, shape: {shape}")
        fake_model[key] = torch.rand(shape, dtype=torch.half)
    return fake_model


def copy_from_model(
    shape_map: Dict[str, List[int]], model: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Copy the leading slice of selected tensors out of ``model``.

    Args:
        shape_map: Maps each tensor name to the shape of the slice to copy.
        model: Source tensors; must contain every key of ``shape_map``.

    Returns:
        A new dict holding, for every key of ``shape_map``, an independent,
        contiguous copy of ``model[key]`` cropped by ``copy_tensor``.

    Raises:
        KeyError: If a key of ``shape_map`` is missing from ``model``.
    """
    copy_model: Dict[str, torch.Tensor] = {}
    for key, shape in shape_map.items():
        source = model[key]
        print("key = ", key)
        print(f"copy tensor {key}, origin shape: {source.shape}, copy shape: {shape}")
        # Tensor.contiguous() is NOT in-place: it returns the contiguous tensor,
        # so the result has to be stored. torch.clone keeps the strides of a
        # dense but non-contiguous input (e.g. a transposed tensor copied at
        # full size), so this call is what guarantees a contiguous layout.
        copy_model[key] = copy_tensor(source, shape).contiguous()
    return copy_model


def copy_tensor(x: torch.Tensor, shape: List[int]) -> torch.Tensor:
    """Return a clone of the leading ``shape`` corner of ``x``.

    Dimension ``i`` is narrowed to its first ``shape[i]`` elements, for every
    ``i`` in ``range(len(shape))``; any further dimensions of ``x`` are kept
    whole. Invalid requests (more entries than ``x`` has dimensions, or a size
    larger than the dimension) are rejected by ``torch.Tensor.narrow``.

    Args:
        x: Source tensor; it is not modified.
        shape: Number of leading elements to keep along each dimension.

    Returns:
        A new tensor that does not share storage with ``x``.
    """
    for i, dim_size in enumerate(shape):
        # narrow() only creates a view; the actual copy happens in clone().
        x = x.narrow(i, 0, dim_size)
    return torch.clone(x)