def inflate_model()

in src/transformers/models/vidtr/vidtr_split.py [0:0]


def inflate_model(model, weights_dir, temporal_size, merge_later):
    r"""Inflate ViT weights along the temporal axis and load them into ``model``.

    The checkpoint is read with ``torch.load`` and each entry is remapped
    according to the first rule below that matches its key:

    1. Key contains ``'pos_embedding'``: stored under ``"pos_embedding"``.
       With ``merge_later`` the first slice along dim 1 (``v[:, :1, :]``) is
       kept once and the remaining slices are tiled ``temporal_size`` times
       along dim 1. Without it the whole tensor is tiled
       ``temporal_size + 1`` times along dim 1.
    2. Key contains ``'cls'``: with ``merge_later`` stored unchanged under
       ``"cls"``. Without it the tensor is tiled into ``"cls_s"``
       (``model.cls_s.shape[0]`` copies along dim 0) and ``"cls_t"``
       (``model.cls_t.shape[1]`` copies along dim 1).
    3. Key contains ``'fc'``: dropped.
    4. Any other key: copied as is.

    Remapped entries whose name is absent from ``model.state_dict()`` are
    ignored. The remaining ones overwrite the matching entries of the state
    dict, which is then loaded back into ``model``. Shapes are not checked
    here, so a mismatch surfaces as an error from ``load_state_dict``.

    Args:
        model (CompactVidTr): model to be inflated; modified in place.
        weights_dir (str): checkpoint location, passed straight to
            ``torch.load``.
        temporal_size (int): Number of frames in temporal.
        merge_later (bool): selects between the two inflation layouts
            described above. When false, ``model`` must expose ``cls_s`` and
            ``cls_t`` attributes if the checkpoint has a ``'cls'`` key.
    """

    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(weights_dir, map_location='cpu')
    model_dict = model.state_dict()

    pretrained_dict = {}
    for k, v in checkpoint.items():
        if 'pos_embedding' in k:
            if merge_later:
                # Presumably index 0 along dim 1 is the class token and the
                # rest are spatial patch positions -- TODO confirm.
                first = v[:, :1, :]
                tiled_rest = v[:, 1:, :].repeat(1, temporal_size, 1)
                pretrained_dict["pos_embedding"] = torch.cat((first, tiled_rest), dim=1)
            else:
                pretrained_dict["pos_embedding"] = v.repeat(1, temporal_size + 1, 1)
        elif 'cls' in k:
            if not merge_later:
                # Tile counts are taken from the target parameters' shapes.
                pretrained_dict["cls_s"] = v.repeat(model.cls_s.shape[0], 1, 1)
                pretrained_dict["cls_t"] = v.repeat(1, model.cls_t.shape[1], 1)
            else:
                pretrained_dict["cls"] = v
        elif 'fc' not in k:
            pretrained_dict[k] = v

    # Keep only entries the target model actually has; everything else in the
    # checkpoint is ignored rather than treated as an error.
    pretrained_dict_ = {k: v for k, v in pretrained_dict.items() if k in model_dict}

    model_dict.update(pretrained_dict_)
    model.load_state_dict(model_dict)
    print("Inflate model success.")