toolkits/model_checkpoints_convertor/glm130b/merge_130b_ckpts.py:
# Merge the eight tensor-parallel checkpoint shards of GLM-130B (SAT format)
# into a single full-model state dict.
import torch
output_state_dict = {}
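# Load the shard saved by each of the eight model-parallel ranks (mp_rank_00 .. mp_rank_07).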
ckpt_dir = '/mnt/glm-ckpts/glm-130b-sat/49300'
sd_split = {i: torch.load(f'{ckpt_dir}/mp_rank_{i:02d}_model_states.pt', map_location='cpu')
            for i in range(8)}
rank_00 = sd_split[0]
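# The word-embedding table is sharded along the vocabulary dimension, so the
# eight pieces are concatenated along dim 0 to recover the full table.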
word_embedding = torch.cat(
    [sd_split[i]['module']['transformer.word_embeddings.weight'] for i in range(8)],
    dim=0)
output_state_dict['transformer.word_embeddings.weight'] = word_embedding
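# Merge the parameters of each of the 70 transformer layers.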
for layer_id in range(70):
    # LayerNorm parameters are replicated across ranks, so copy them from rank 0.
    input_layernorm_weight = f'transformer.layers.{layer_id}.input_layernorm.weight'
    output_state_dict[input_layernorm_weight] = rank_00['module'][input_layernorm_weight]
    print(input_layernorm_weight)
    input_layernorm_bias = f'transformer.layers.{layer_id}.input_layernorm.bias'
    output_state_dict[input_layernorm_bias] = rank_00['module'][input_layernorm_bias]
    print(input_layernorm_bias)
    # The QKV projection is split along its output dimension and the attention
    # output projection along its input dimension, so concatenate the shards
    # along dim 0 and dim 1 respectively.
    self_att_qkv_weight = f'transformer.layers.{layer_id}.attention.query_key_value.weight'
    self_att_qkv_bias = f'transformer.layers.{layer_id}.attention.query_key_value.bias'
    self_att_dense_weight = f'transformer.layers.{layer_id}.attention.dense.weight'
    tmp_qkv_weight = []
    tmp_qkv_bias = []
    tmp_dense_weight = []
    for i in range(8):
        rank_i = sd_split[i]
        tmp_qkv_weight.append(rank_i['module'][self_att_qkv_weight])
        tmp_qkv_bias.append(rank_i['module'][self_att_qkv_bias])
        tmp_dense_weight.append(rank_i['module'][self_att_dense_weight])
    output_state_dict[self_att_qkv_weight] = torch.cat(tmp_qkv_weight, dim=0)
    output_state_dict[self_att_qkv_bias] = torch.cat(tmp_qkv_bias, dim=0)
    output_state_dict[self_att_dense_weight] = torch.cat(tmp_dense_weight, dim=1)
    print(self_att_qkv_weight)
    print(self_att_qkv_bias)
    print(self_att_dense_weight)
    # The attention output bias and the post-attention LayerNorm are replicated,
    # so take them from rank 0.
    self_att_dense_bias = f'transformer.layers.{layer_id}.attention.dense.bias'
    output_state_dict[self_att_dense_bias] = rank_00['module'][self_att_dense_bias]
    print(self_att_dense_bias)
    post_layernorm_weight = f'transformer.layers.{layer_id}.post_attention_layernorm.weight'
    output_state_dict[post_layernorm_weight] = rank_00['module'][post_layernorm_weight]
    print(post_layernorm_weight)
    post_layernorm_bias = f'transformer.layers.{layer_id}.post_attention_layernorm.bias'
    output_state_dict[post_layernorm_bias] = rank_00['module'][post_layernorm_bias]
    print(post_layernorm_bias)
    # dense_h_to_4h is split along its output dimension: concatenate the weight
    # and bias shards along dim 0.
    mlp_h_weight = f'transformer.layers.{layer_id}.mlp.dense_h_to_4h.weight'
    mlp_h_bias = f'transformer.layers.{layer_id}.mlp.dense_h_to_4h.bias'
    tmp_mlp_weight = []
    tmp_mlp_bias = []
    for i in range(8):
        rank_i = sd_split[i]
        tmp_mlp_weight.append(rank_i['module'][mlp_h_weight])
        tmp_mlp_bias.append(rank_i['module'][mlp_h_bias])
    output_state_dict[mlp_h_weight] = torch.cat(tmp_mlp_weight, dim=0)
    output_state_dict[mlp_h_bias] = torch.cat(tmp_mlp_bias, dim=0)
    print(mlp_h_weight)
    print(mlp_h_bias)
    # dense_4h_to_h is split along its input dimension: concatenate the weight
    # shards along dim 1; its bias is replicated, so copy it from rank 0.
    mlp_4h_weight = f'transformer.layers.{layer_id}.mlp.dense_4h_to_h.weight'
    tmp_mlp_4h_weight = []
    for i in range(8):
        tmp_mlp_4h_weight.append(sd_split[i]['module'][mlp_4h_weight])
    output_state_dict[mlp_4h_weight] = torch.cat(tmp_mlp_4h_weight, dim=1)
    print(mlp_4h_weight)
    mlp_4h_bias = f'transformer.layers.{layer_id}.mlp.dense_4h_to_h.bias'
    output_state_dict[mlp_4h_bias] = rank_00['module'][mlp_4h_bias]
    print(mlp_4h_bias)
# The final LayerNorm is replicated across ranks, so copy it from rank 0.
output_state_dict['transformer.final_layernorm.weight'] = rank_00['module']['transformer.final_layernorm.weight']
print('final_layernorm.weight')
output_state_dict['transformer.final_layernorm.bias'] = rank_00['module']['transformer.final_layernorm.bias']
print('final_layernorm.bias')
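# Write the merged single-file state dict.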
torch.save(output_state_dict, "/mnt/glm-ckpts/glm-130b-sat/pytorch_model.bin")
print("done")