in src/transformers/models/vidtr/vidtr_split.py [0:0]
def inflate_model(model, weights_dir, temporal_size, merge_later):
    r"""Inflate ViT weights along the temporal axis and load them into ``model``.

    The checkpoint is read with ``torch.load`` and iterated as a mapping of
    parameter name to tensor. Each entry is routed by substring match on its
    name:

    * names containing ``pos_embedding`` are expanded along dim 1 and stored
      under the key ``"pos_embedding"``;
    * names containing ``cls`` are stored under ``"cls"`` (``merge_later``)
      or tiled into ``"cls_s"`` / ``"cls_t"`` (otherwise);
    * any other name is copied unchanged unless it contains ``fc``.

    Only the resulting keys that already exist in ``model.state_dict()`` are
    applied; the rest are dropped without notice.

    Args:
        model (CompactVidTr): model to be inflated. When ``merge_later`` is
            false it must expose ``cls_s`` and ``cls_t``, whose shapes give
            the repeat counts.
        weights_dir (str): inflate model dir (path passed to ``torch.load``).
        temporal_size (int): Number of frames in temporal; used as the repeat
            factor for the positional embedding.
        merge_later (bool): if true, the first dim-1 slice of the positional
            embedding is kept once and the remaining slices are repeated
            ``temporal_size`` times; if false, the whole positional embedding
            is repeated ``temporal_size + 1`` times.

    Returns:
        None. ``model`` is updated in place through ``load_state_dict``.
    """
    # NOTE(review): depending on the installed torch version / its
    # ``weights_only`` default, torch.load may unpickle arbitrary objects --
    # only point this at trusted checkpoint files.
    source_weights = torch.load(weights_dir, map_location='cpu')
    target_state = model.state_dict()

    inflated = {}
    for name, tensor in source_weights.items():
        if 'pos_embedding' in name:
            if merge_later:
                # Keep the leading slice once (presumably the class-token
                # position -- confirm) and tile everything after it.
                leading = tensor[:, :1, :]
                tiled = tensor[:, 1:, :].repeat(1, temporal_size, 1)
                inflated["pos_embedding"] = torch.cat((leading, tiled), dim=1)
            else:
                inflated["pos_embedding"] = tensor.repeat(1, temporal_size + 1, 1)
            continue

        if 'cls' in name:
            if merge_later:
                inflated["cls"] = tensor
            else:
                # Repeat counts are taken from the model's own parameters.
                inflated["cls_s"] = tensor.repeat(model.cls_s.shape[0], 1, 1)
                inflated["cls_t"] = tensor.repeat(1, model.cls_t.shape[1], 1)
            continue

        # Everything else is copied as-is, except names containing 'fc'.
        if 'fc' not in name:
            inflated[name] = tensor

    # Overwrite only entries the model actually owns; unknown keys are ignored.
    for name, tensor in inflated.items():
        if name in target_state:
            target_state[name] = tensor

    model.load_state_dict(target_state)
    print("Inflate model success.")