def convert()

in recipes/utilities/convlm_serializer/save_pytorch_model.py [0:0]


def convert(model_state, key, suffix=""):
    """Serialize one entry of a state dict into a single space-separated line.

    The emitted text is ``<name> <ndim> <dims...> <values...>`` where
    ``<dims...>`` are the tensor dimensions and ``<values...>`` are the
    elements flattened in row-major order.

    ``<name>`` is ``key`` with its leading dotted component removed and
    ``suffix`` inserted just before its final component, e.g.
    ``"a.b.c.weight"`` with ``suffix="_0"`` becomes ``"b.c_0.weight"``.

    When ``key`` contains ``"conv"``, the tensor is 3-D, and either
    (a) ``key`` contains ``"weight_v"`` and the tensor's first dimension is 1,
    or (b) ``key`` contains ``"weight_g"`` and the matching ``"weight_v"``
    entry in ``model_state`` has first dimension 1,
    the tensor is written as a 2-D matrix instead. The leading axis is dropped
    (index 0 is taken) and the remaining two axes are transposed, both in the
    reported dims and in the value order.

    Args:
        model_state: Mapping from key to tensor; entries must expose
            ``.shape`` and ``.cpu().numpy()`` (e.g. a PyTorch ``state_dict``).
        key: Dotted key of the entry to serialize.
        suffix: Text spliced into the emitted name before the last component.

    Returns:
        The serialized line (no trailing newline).
    """
    tensor = model_state[key]
    parts = key.split(".")
    name = ".".join(parts[1:-1]) + suffix + "." + parts[-1]

    # Same short-circuit order as a plain `a and b and (c or d)` chain: the
    # companion weight_v lookup only happens when the weight_v test fails.
    as_linear = (
        "conv" in key
        and len(tensor.shape) == 3
        and (
            ("weight_v" in key and tensor.shape[0] == 1)
            or (
                "weight_g" in key
                and model_state[key.replace("weight_g", "weight_v")].shape[0] == 1
            )
        )
    )

    if as_linear:
        # Drop the size-1 leading axis, then transpose the remaining 2-D slice;
        # the reported dims are reversed to match the transposed layout.
        dims = list(tensor.shape[1:])[::-1]
        values = tensor.cpu().numpy()[0].T.flatten()
    else:
        dims = list(tensor.shape)
        values = tensor.cpu().numpy().flatten()

    # Segments are joined with single spaces exactly as before. An empty dims
    # or values list therefore still leaves its separating space in place.
    shape_text = str(len(dims)) + " " + " ".join(map(str, dims))
    values_text = " ".join(map(str, values))
    return " ".join([name, shape_text, values_text])