recipes/utilities/convlm_serializer/save_pytorch_model.py
from __future__ import absolute_import, division, print_function, unicode_literals

import sys
from collections import defaultdict

import torch


def convert(model_state, key, suffix=""):
    """Serialize one parameter tensor as a single text line:
    "<name><suffix>.<param> <ndims> <dim_1> ... <dim_k> <value_1> <value_2> ...".
    """
    string = ""
    param = model_state[key]
    # param name: drop the leading module prefix (e.g. "decoder.") and append suffix
    string += ".".join(key.split(".")[1:-1]) + suffix + "." + key.split(".")[-1] + " "
    change_to_lin_layer = False
    if "conv" in key and len(param.shape) == 3:
        # A 3D conv weight whose leading dimension is 1 (checked on weight_v for
        # weight-normalized params) is equivalent to a linear layer.
        if ("weight_v" in key and param.shape[0] == 1) or (
            "weight_g" in key
            and model_state[key.replace("weight_g", "weight_v")].shape[0] == 1
        ):
            change_to_lin_layer = True
    if change_to_lin_layer:
        # param shapes: drop the leading dimension and reverse the remaining ones
        string += (
            str(len(param.shape) - 1) + " " + " ".join(map(str, param.shape[1:][::-1]))
        )
        # param matrix, transposed to match the reversed shape
        string += " " + " ".join(map(str, param.cpu().numpy()[0].T.flatten()))
    else:
        # param shapes
        string += str(len(param.shape)) + " " + " ".join(map(str, param.shape))
        # param matrix
        string += " " + " ".join(map(str, param.cpu().numpy().flatten()))
    return string
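

# A minimal sketch (not part of the original recipe) of what convert() emits for a
# toy state dict; the key "model.layers.0.weight" and the 2x2 tensor are made-up
# examples, not taken from a real checkpoint.
def _example_convert_output():
    state = {"model.layers.0.weight": torch.arange(4.0).reshape(2, 2)}
    line = convert(state, "model.layers.0.weight")
    # line == "layers.0.weight 2 2 2 0.0 1.0 2.0 3.0"
    return line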
def save_model(pytorch_model_path, dst):
    """Dump every parameter of the checkpoint at `pytorch_model_path` to the text
    file `dst`, one line per parameter, emitting each layer's buffered "projection"
    params right after that layer's own params.
    """
    model_state = torch.load(pytorch_model_path)
    model_state = model_state["model"]

    add_string = ""
    prev_key = ""
    with open(dst, "w") as f:
        projections = defaultdict(list)
        for key in model_state:
            print("Process param", key)
            if "version" in key:
                print("Skip", key)
                continue
            if "projection" in key:
                # Buffer projection params under their layer index; they are
                # written out after that layer's regular params.
                projections[key.split(".")[-2]].append(
                    convert(model_state, key, "-projection")
                )
            else:
                if prev_key != key.split(".")[2]:
                    # Starting a new parameter group: flush projection lines
                    # buffered for the previous group, then queue this group's.
                    if add_string != "":
                        f.write(add_string + "\n")
                        add_string = ""
                    prev_key = key.split(".")[2]
                    if key.split(".")[2] in projections:
                        add_string = "\n".join(projections[key.split(".")[2]])
                f.write(convert(model_state, key) + "\n")
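

# Small helper sketch (not used by the conversion itself): lists the parameter
# names and shapes found in a checkpoint, assuming the same layout save_model()
# expects (a dict whose "model" entry maps names to tensors).
def _print_checkpoint_params(pytorch_model_path):
    model_state = torch.load(pytorch_model_path, map_location="cpu")["model"]
    for key, param in model_state.items():
        print(key, tuple(param.shape))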


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: save_pytorch_model.py [path/to/model] [dst]")
        sys.exit(1)
    print("Converting the model")
    path = sys.argv[1]
    dst = sys.argv[2]
    save_model(path, dst)
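
# Example invocation (the file names are placeholders, not files shipped with the recipe):
#   python save_pytorch_model.py checkpoint_best.pt converted_model.txt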