def save_model()

in recipes/utilities/convlm_serializer/save_pytorch_model.py [0:0]


def save_model(pytorch_model_path, dst):
    """Serialize the parameters of a PyTorch checkpoint to a text file.

    Loads the checkpoint at ``pytorch_model_path``, takes its ``"model"``
    entry, and writes one ``convert(model_state, key)`` line per parameter
    to ``dst``, in the state dict's iteration order. Parameters whose name
    contains ``"version"`` are skipped.

    Parameters whose name contains ``"projection"`` are not written where
    they occur. Instead they are converted with the ``"-projection"``
    suffix, grouped by the second-to-last dotted component of their name,
    and emitted right after the run of consecutive parameters whose third
    dotted component equals that same value.

    Args:
        pytorch_model_path: Path passed to ``torch.load``. The loaded
            object must contain a ``"model"`` entry.
        dst: Output text file path. It is opened in ``"w"`` mode and so
            overwritten.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
        IndexError: If a non-projection parameter name has fewer than
            three dotted components, or a projection name has fewer
            than two.
    """
    model_state = torch.load(pytorch_model_path)["model"]

    # Converted projection lines waiting to be written once the current
    # group (third dotted name component) ends.
    pending = ""
    prev_group = ""

    with open(dst, "w") as f:
        projections = defaultdict(list)
        for key in model_state:
            print("Process param", key)
            if "version" in key:
                print("Skip", key)
                continue

            parts = key.split(".")
            if "projection" in key:
                # NOTE(review): a projection is only emitted if its group's
                # non-projection params come *after* it in state-dict order;
                # projections seen later than their group are never written.
                projections[parts[-2]].append(
                    convert(model_state, key, "-projection")
                )
                continue

            group = parts[2]
            if group != prev_group:
                # A new group starts: emit projections queued for the old one.
                if pending:
                    f.write(pending + "\n")
                pending = ""
            prev_group = group
            if group in projections:
                pending = "\n".join(projections[group])
            f.write(convert(model_state, key) + "\n")

        # Projections queued for the final group would otherwise be lost,
        # because the in-loop flush only fires when the group changes.
        if pending:
            f.write(pending + "\n")