toolkits/distributed_checkpoints_convertor/impl/deepseek_v3/patch.py
from torch import nn
from accelerate import init_empty_weights


class NormedLinear(nn.Module):
    """Norm layer followed by a vocabulary projection, used as the
    ``shared_head`` (``shared_head.norm`` / ``shared_head.head``) of an MTP layer."""

    def __init__(self, norm_class, config):
        super().__init__()
        self.norm = norm_class(config.hidden_size, eps=config.rms_norm_eps)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Used as a decorator, ``init_empty_weights`` creates every module built inside
# this function on the meta device, so the appended layers take no real memory
# until their weights are loaded.
@init_empty_weights(include_buffers=True)
def add_mtp_layers(hfmodel, config, mtp_num_layers):
    """Append ``mtp_num_layers`` multi-token prediction (MTP) layers to a
    DeepSeek-V3 Hugging Face model and patch in their extra sub-modules."""
    basic_decoder_layer_class = hfmodel.model.layers[-1].__class__
    start_layer_id = config.num_hidden_layers
    for mtp_layer_id in range(start_layer_id, start_layer_id + mtp_num_layers):
        # Each MTP layer reuses the regular decoder-layer class; its layer
        # index continues after the last regular layer.
        hfmodel.model.layers.append(
            basic_decoder_layer_class(config, mtp_layer_id)
        )
        # NOTE: patch the attributes that distinguish an MTP layer from a
        # regular decoder layer (names follow the DeepSeek-V3 checkpoint).
        mtp_layer: nn.Module = hfmodel.model.layers[-1]
        # Per-layer token embedding for the shifted input tokens.
        mtp_layer.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, config.pad_token_id
        )
        norm_class = hfmodel.model.norm.__class__
        # Norms for the embedded tokens (enorm) and the previous hidden state
        # (hnorm); their concatenation is projected back to hidden_size by eh_proj.
        mtp_layer.enorm = norm_class(config.hidden_size, eps=config.rms_norm_eps)
        mtp_layer.hnorm = norm_class(config.hidden_size, eps=config.rms_norm_eps)
        mtp_layer.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        mtp_layer.shared_head = NormedLinear(norm_class, config)
    return hfmodel
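

# A minimal usage sketch, not part of the converter itself: it shows how
# ``add_mtp_layers`` could be driven from a DeepSeek-V3 Hugging Face config.
# The checkpoint id and the ``num_nextn_predict_layers`` field are assumptions
# based on the public DeepSeek-V3 release; substitute your own paths.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3", trust_remote_code=True  # assumed checkpoint id
    )
    # Build the base model on the meta device so no real memory is allocated.
    with init_empty_weights(include_buffers=True):
        hfmodel = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    hfmodel = add_mtp_layers(
        hfmodel, config, getattr(config, "num_nextn_predict_layers", 1)
    )
    print(hfmodel.model.layers[-1])  # the appended MTP layer with patched attributes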