optimum/neuron/models/inference/phi3/modeling_phi3.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi3 model for NXD inference."""
import gc
import logging
import torch
from transformers.models.phi3.modeling_phi3 import Phi3Config
from ..backend.config import NxDNeuronConfig # noqa: E402
from ..llama.modeling_llama import (
LlamaNxDModelForCausalLM,
NxDLlamaModel,
)
logger = logging.getLogger("Neuron")
class Phi3NxDModelForCausalLM(LlamaNxDModelForCausalLM):
"""
Phi3 model for NxD inference.
This class inherits from the Neuron Llama model class since the Phi3 modeling is just a
Llama modeling with fused qkv and mlp projections.
The state_dict loading method is modified to unfuse the weights at loading time.
"""
_model_cls = NxDLlamaModel
@staticmethod
def convert_hf_to_neuron_state_dict(state_dict: dict, config: Phi3Config, neuron_config: NxDNeuronConfig) -> dict:
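        """Convert a fused Phi3 Hugging Face state dict to the layout expected by the Neuron Llama modeling code.

        Depending on ``neuron_config.fused_qkv``, the fused ``qkv_proj`` weight is either renamed to
        ``Wqkv`` or split back into separate ``q_proj``, ``k_proj`` and ``v_proj`` weights. The fused
        ``gate_up_proj`` MLP weight is always split into ``gate_proj`` and ``up_proj`` weights, and
        rank tensors are added to facilitate rank usage in attention and in the base model.
        """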
        for l in range(config.num_hidden_layers):  # noqa: E741
            if neuron_config.fused_qkv:
                # Just rename the qkv projection to the expected name
                state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = state_dict.pop(
                    f"layers.{l}.self_attn.qkv_proj.weight"
                )
            else:
                # Unfuse the qkv projections as expected by the NeuronAttentionBase
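                # In the HF Phi3 layout, the fused weight's rows are ordered q, then k, then v: the first
                # hidden_size rows are the q projection, and the k and v projections have equal row counts.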
                fused_qkv = state_dict[f"layers.{l}.self_attn.qkv_proj.weight"]
                state_dict[f"layers.{l}.self_attn.q_proj.weight"] = fused_qkv[: config.hidden_size, :].clone().detach()
                k_weight, v_weight = torch.chunk(fused_qkv[config.hidden_size :, :], 2, dim=0)
                state_dict[f"layers.{l}.self_attn.k_proj.weight"] = k_weight.clone().detach()
                state_dict[f"layers.{l}.self_attn.v_proj.weight"] = v_weight.clone().detach()
                state_dict.pop(f"layers.{l}.self_attn.qkv_proj.weight")
            gc.collect()
            # Unfuse the mlp projections
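            # gate_up_proj.weight stacks the gate rows on top of the up rows (2 * intermediate_size rows
            # in the HF Phi3 layout), so a single chunk along dim 0 recovers both projections.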
            gate_weight, up_weight = torch.chunk(state_dict[f"layers.{l}.mlp.gate_up_proj.weight"], 2, dim=0)
            state_dict[f"layers.{l}.mlp.gate_proj.weight"] = gate_weight.clone().detach()
            state_dict[f"layers.{l}.mlp.up_proj.weight"] = up_weight.clone().detach()
            state_dict.pop(f"layers.{l}.mlp.gate_up_proj.weight")
            gc.collect()

        if neuron_config.vocab_parallel:
            # TODO: this hack can be removed after replication_id is ready to use
            state_dict["embed_tokens.rank_util.rank"] = torch.arange(0, neuron_config.local_ranks_size)

        # to facilitate rank usage in attention
        num_layers = config.num_hidden_layers
        tp_degree = neuron_config.tp_degree
        for i in range(num_layers):
            state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)

        # to facilitate rank usage in base model
        state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
        return state_dict
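

# Illustrative sketch (not part of the module above): how the row-wise qkv split performed in
# convert_hf_to_neuron_state_dict behaves on a toy fused tensor, assuming hidden_size == 4 and
# equally sized k/v projections:
#
#     import torch
#
#     fused_qkv = torch.arange(32.0).reshape(8, 4)  # rows: 4 for q, 2 for k, 2 for v
#     q = fused_qkv[:4, :]
#     k, v = torch.chunk(fused_qkv[4:, :], 2, dim=0)
#     assert q.shape == (4, 4) and k.shape == (2, 4) and v.shape == (2, 4)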