# optimum/neuron/models/inference/mixtral/modeling_mixtral.py
import gc

import torch


def convert_mixtral_to_neuron_state_dict(neuron_state_dict, config, neuron_config):
"""
Helper function which returns the model weights from the mixtral model in a state dictionary compatible with the stucture of the neuron MoE model.
"""
    assert neuron_config.glu_mlp is True, "Only GLU MLP is supported for the Mixtral Top-K model"
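    # Mixtral experts use a GLU-style MLP: w1 (gate) and w3 (up) are combined
    # elementwise and w2 (down) projects the result back to the hidden size.
    # The Neuron MoE layout below fuses w1/w3 into a single gate_up_proj tensor.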
for l in range(config.num_hidden_layers): # noqa: E741
# Copy router weights
neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"].detach().clone()
)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"]
        # Read shapes, device, and dtype from the first expert's gate projection (w1)
        w1_weight = neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.0.w1.weight"]
        intermediate_size, hidden_size = w1_weight.shape
        device = w1_weight.device
        dtype = w1_weight.dtype
        # Copy the MLP parameters
gate_up_proj = torch.empty(
config.num_local_experts,
hidden_size,
2 * intermediate_size,
dtype=dtype,
device=device,
)
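        # gate_up_proj is laid out as [num_experts, hidden, 2 * intermediate]: the first
        # half of the last dim holds gate_proj (w1) and the second half holds up_proj (w3),
        # each transposed from the Hugging Face [intermediate, hidden] layout.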
for e in range(config.num_local_experts):
# Copy gate_proj and up_proj after concatenation
gate_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"].T.detach().clone()
)
up_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"].T.detach().clone()
)
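            # torch.narrow returns views into gate_up_proj, so the copy_ calls below
            # write the transposed weights directly in place (equivalent to
            # gate_up_proj[e, :, :intermediate_size] = gate_proj_weights, and likewise
            # for the upper half).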
gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1)
gate_proj_slice = torch.narrow(gate_up_proj_slice, 2, 0, intermediate_size)
gate_proj_slice.copy_(gate_proj_weights)
up_proj_slice = torch.narrow(gate_up_proj_slice, 2, intermediate_size, intermediate_size)
up_proj_slice.copy_(up_proj_weights)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"]
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"]
neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj
down_proj = torch.empty(
config.num_local_experts,
intermediate_size,
hidden_size,
dtype=dtype,
device=device,
)
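        # Per-expert down projections, stored as [num_experts, intermediate, hidden]
        # (each expert's w2 transposed), filled with the same view-and-copy pattern.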
for e in range(config.num_local_experts):
# Copy down_proj
down_proj_weights = (
neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"].T.detach().clone()
)
down_proj_slice = torch.narrow(down_proj, 0, e, 1)
down_proj_slice.copy_(down_proj_weights)
del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"]
neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj
        # Release the per-layer HF tensors deleted above before the next iteration
        gc.collect()
return neuron_state_dict
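

# --- Usage sketch (illustrative only, not part of the module) -----------------
# A minimal sketch of driving the conversion on a tiny randomly initialized
# Mixtral. Two assumptions for illustration: the caller strips the "model."
# prefix from the Hugging Face state dict keys, and a SimpleNamespace with
# glu_mlp=True stands in for the real optimum-neuron neuron_config.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import MixtralConfig, MixtralForCausalLM

    # Tiny config so the example runs quickly on CPU.
    config = MixtralConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_local_experts=4,
        vocab_size=128,
    )
    model = MixtralForCausalLM(config)

    # The converter expects keys rooted at "layers.{l}...." (assumption), so drop
    # the leading "model." prefix that transformers uses.
    state_dict = {
        (k[len("model."):] if k.startswith("model.") else k): v
        for k, v in model.state_dict().items()
    }

    neuron_config = SimpleNamespace(glu_mlp=True)  # hypothetical stand-in
    converted = convert_mixtral_to_neuron_state_dict(state_dict, config, neuron_config)

    # Expect [num_local_experts, hidden, 2 * intermediate] -> torch.Size([4, 32, 128])
    print(converted["layers.0.mlp.expert_mlps.mlp_op.gate_up_proj.weight"].shape)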