in recipes/utilities/convlm_serializer/save_pytorch_model.py [0:0]
def save_model(pytorch_model_path, dst):
    """Dump the parameters of a PyTorch checkpoint to a text file.

    The checkpoint is read with ``torch.load`` and its ``"model"`` entry is
    walked key by key, in iteration order.  Every parameter is turned into a
    string by ``convert`` and written to ``dst`` followed by a newline.

    Special handling of keys:

    * Keys containing ``"version"`` are skipped.
    * Keys containing ``"projection"`` are not written where they occur.
      They are converted with the ``"-projection"`` suffix and grouped by the
      second-to-last dot-separated component of the key.  A group is written
      once the parameters whose third dot-separated component equals the
      group name have all been written, i.e. when a key with a different
      third component shows up, or when the end of the state is reached.

    Args:
        pytorch_model_path: Checkpoint location, passed to ``torch.load``.
        dst: Path of the output text file; opened in ``"w"`` mode, so any
            existing content is overwritten.

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
        IndexError: If a non-projection, non-version key has fewer than three
            dot-separated components.
    """
    # NOTE: torch.load unpickles the file, so only run this on checkpoints
    # from a trusted source.
    # Loading onto the CPU lets a checkpoint that was saved on a GPU be
    # converted on a machine without CUDA.
    model_state = torch.load(pytorch_model_path, map_location="cpu")["model"]

    projections = defaultdict(list)
    pending = ""  # projection lines waiting for their layer to finish
    prev_layer = ""

    with open(dst, "w") as f:
        for key in model_state:
            print("Process param", key)
            if "version" in key:
                print("Skip", key)
                continue
            if "projection" in key:
                projections[key.split(".")[-2]].append(
                    convert(model_state, key, "-projection")
                )
                continue

            layer = key.split(".")[2]
            if layer != prev_layer:
                # A new layer starts: emit the projections deferred for the
                # previous one before writing anything else.
                if pending:
                    f.write(pending + "\n")
                    pending = ""
                prev_layer = layer
            # Membership test on a defaultdict does not create an entry.
            if layer in projections:
                pending = "\n".join(projections[layer])
            f.write(convert(model_state, key) + "\n")

        # Without this final flush the projections of the last processed
        # layer would be lost when no further key follows it.
        if pending:
            f.write(pending + "\n")