def finetunable_GPT2_params()

in opacus_lab/models/GPT2/train.py


def finetunable_GPT2_params(model, finetune):
    # Works on the refactored GPT2: returns the parameters whose layer
    # index is at least `finetune` and freezes everything else in place.
    def extract_finetune_index(name):
        # Subroutine that parses a parameter name into a layer index.
        # Defaulting to -inf (rather than None, which would raise a
        # TypeError in the comparison below) freezes unrecognized params.
        ft_idx = -float("inf")
        if "emb" in name:
            ft_idx = -1  # embeddings: fine-tuned only when finetune <= -1
        elif name.startswith("transformers"):
            ft_idx = int(name.split(".")[1])  # transformer block index
        elif "head" in name:
            ft_idx = float("inf")  # always fine-tune the head
        return ft_idx

    params = []
    for name, param in model.named_parameters():
        if extract_finetune_index(name) >= finetune and param.requires_grad:
            params.append(param)
        else:
            param.requires_grad = False
    return params
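
A minimal usage sketch, assuming `model` is the refactored GPT2 from this repo (with parameter names following the `transformers.<block>.<...>` convention parsed above); the cutoff index and optimizer settings are illustrative, not prescribed by the source:

    import torch

    # Fine-tune transformer blocks 10 and up, plus the head; freeze
    # blocks 0-9 and the embeddings.
    params = finetunable_GPT2_params(model, finetune=10)

    # Only the selected parameters go to the optimizer; the rest now
    # have requires_grad=False and accumulate no gradients.
    optimizer = torch.optim.AdamW(params, lr=1e-4)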