optimum/neuron/models/inference/phi3/modeling_phi3.py

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi3 model for NXD inference."""

import gc
import logging

import torch
from transformers.models.phi3.modeling_phi3 import Phi3Config

from ..backend.config import NxDNeuronConfig  # noqa: E402
from ..llama.modeling_llama import (
    LlamaNxDModelForCausalLM,
    NxDLlamaModel,
)


logger = logging.getLogger("Neuron")


class Phi3NxDModelForCausalLM(LlamaNxDModelForCausalLM):
    """
    Phi3 model for NxD inference.

    This class inherits from the Neuron Llama model class since the Phi3 modeling is just a Llama modeling
    with fused qkv and mlp projections. The state_dict loading method is modified to unfuse the weights at
    loading time.
    """

    _model_cls = NxDLlamaModel

    @staticmethod
    def convert_hf_to_neuron_state_dict(state_dict: dict, config: Phi3Config, neuron_config: NxDNeuronConfig) -> dict:
        for l in range(config.num_hidden_layers):  # noqa: E741
            if neuron_config.fused_qkv:
                # Just rename the qkv projection to the expected name
                state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = state_dict.pop(
                    f"layers.{l}.self_attn.qkv_proj.weight"
                )
            else:
                # Unfuse the qkv projections as expected by the NeuronAttentionBase
                fused_qkv = state_dict[f"layers.{l}.self_attn.qkv_proj.weight"]
                state_dict[f"layers.{l}.self_attn.q_proj.weight"] = (
                    fused_qkv[: config.hidden_size, :].clone().detach()
                )
                k_weight, v_weight = torch.chunk(fused_qkv[config.hidden_size :, :], 2, dim=0)
                state_dict[f"layers.{l}.self_attn.k_proj.weight"] = k_weight.clone().detach()
                state_dict[f"layers.{l}.self_attn.v_proj.weight"] = v_weight.clone().detach()
                state_dict.pop(f"layers.{l}.self_attn.qkv_proj.weight")
                gc.collect()
            # Unfuse the mlp projections
            gate_weight, up_weight = torch.chunk(state_dict[f"layers.{l}.mlp.gate_up_proj.weight"], 2, dim=0)
            state_dict[f"layers.{l}.mlp.gate_proj.weight"] = gate_weight.clone().detach()
            state_dict[f"layers.{l}.mlp.up_proj.weight"] = up_weight.clone().detach()
            state_dict.pop(f"layers.{l}.mlp.gate_up_proj.weight")
            gc.collect()

        if neuron_config.vocab_parallel:
            # TODO: this hack can be removed after replication_id is ready to use
            state_dict["embed_tokens.rank_util.rank"] = torch.arange(0, neuron_config.local_ranks_size)

        # to facilitate rank usage in attention
        num_layers = config.num_hidden_layers
        tp_degree = neuron_config.tp_degree
        for i in range(num_layers):
            state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
        # to facilitate rank usage in base model
        state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
        return state_dict
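
# ----------------------------------------------------------------------------
# For reference, a minimal, self-contained sketch of the weight-unfusing logic
# above, run on synthetic tensors with hypothetical tiny dimensions (these
# sizes and names are illustrative only and are not part of the file above).
# It assumes q, k and v all have `hidden_size` rows, i.e. no grouped-query
# attention; the `torch.chunk` split works the same way either way because k
# and v always occupy an equal number of rows in the fused matrix.
# ----------------------------------------------------------------------------
#
#     import torch
#
#     # Hypothetical toy dimensions for illustration.
#     hidden_size = 8
#     intermediate_size = 16
#
#     # Phi3 stores attention projections row-stacked as [q; k; v].
#     fused_qkv = torch.randn(3 * hidden_size, hidden_size)
#     q_proj = fused_qkv[:hidden_size, :].clone().detach()
#     k_proj, v_proj = (
#         t.clone().detach() for t in torch.chunk(fused_qkv[hidden_size:, :], 2, dim=0)
#     )
#
#     # The MLP projection is row-stacked as [gate; up].
#     gate_up = torch.randn(2 * intermediate_size, hidden_size)
#     gate_proj, up_proj = (t.clone().detach() for t in torch.chunk(gate_up, 2, dim=0))
#
#     assert q_proj.shape == k_proj.shape == v_proj.shape == (hidden_size, hidden_size)
#     assert gate_proj.shape == up_proj.shape == (intermediate_size, hidden_size)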