in optimum/tpu/modeling_llama.py [0:0]
def load_hook(self, state_dict, _prefix, *_args):
    # Shard the full-checkpoint weights so that each tensor-parallel rank
    # only loads its own slice of the projection matrices.
    num_attn_heads = self.config.num_attention_heads
    head_dim = self.config.hidden_size // self.config.num_attention_heads
    hidden_size = self.config.hidden_size

    def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
        # Keep this rank's contiguous 1/world_size slice along `axis`.
        axis_len = tensor.shape[axis]
        split_len = axis_len // self.world_size
        split_start = split_len * self.rank
        split_end = split_start + split_len
        tensor = torch.moveaxis(tensor, axis, 0)
        tensor = tensor[split_start:split_end, ...]
        tensor = torch.moveaxis(tensor, 0, axis)
        return tensor

    for k, v in state_dict.items():
        # MLP: gate/up projections are split on the output dim (column-parallel),
        # the down projection on the input dim (row-parallel).
        if re.fullmatch(r"model.layers.\d+.mlp.(gate_proj|up_proj).weight", k):
            v = split(v, 0)
        if re.fullmatch(r"model.layers.\d+.mlp.down_proj.weight", k):
            v = split(v, 1)
        # Attention: k/v projections are split on the output dim; q_proj and
        # o_proj are reshaped so the split falls on whole attention heads.
        if re.fullmatch(r"model.layers.\d+.self_attn.(k|v)_proj.weight", k):
            v = split(v, 0)
        if re.fullmatch(r"model.layers.\d+.self_attn.q_proj.weight", k):
            v = v.reshape(num_attn_heads, head_dim, hidden_size)
            v = split(v, 0)
            v = v.reshape(-1, hidden_size)
        if re.fullmatch(r"model.layers.\d+.self_attn.o_proj.weight", k):
            v = v.reshape(hidden_size, num_attn_heads, head_dim)
            v = split(v, 1)
            v = v.reshape(hidden_size, -1)
        # The LM head is split along the vocabulary dim.
        if k == "lm_head.weight":
            v = split(v, 0)
        # Update state_dict with the (possibly sharded) tensor.
        state_dict[k] = v
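
# --- Illustration only, not part of modeling_llama.py ---
# A minimal standalone sketch of the sharding performed by `split` above,
# assuming world_size=2, hidden_size=8 and num_attention_heads=4 (head_dim=2).
# The `shard` helper and all values below are made up for this example; it
# shows that each rank ends up with the rows for half of the q_proj heads.
import torch

def shard(tensor: torch.Tensor, axis: int, rank: int, world_size: int) -> torch.Tensor:
    # Same slicing as `split`, written as a free function for clarity.
    split_len = tensor.shape[axis] // world_size
    start = split_len * rank
    tensor = torch.moveaxis(tensor, axis, 0)[start:start + split_len, ...]
    return torch.moveaxis(tensor, 0, axis)

hidden_size, num_attn_heads = 8, 4
head_dim = hidden_size // num_attn_heads
q_proj = torch.arange(hidden_size * hidden_size, dtype=torch.float32).reshape(hidden_size, hidden_size)

for rank in range(2):
    local = shard(q_proj.reshape(num_attn_heads, head_dim, hidden_size), 0, rank, 2)
    local = local.reshape(-1, hidden_size)
    print(rank, local.shape)  # each rank holds a (4, 8) slice of the (8, 8) weight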