in src/transformers/models/vidtr/vidtr_split.py [0:0]
def inflate_model(model, weights_dir, temporal_size, merge_later):
    r"""Load a ViT checkpoint into ``model``, tiling selected entries over time.

    The checkpoint at ``weights_dir`` is read onto the CPU and walked entry by
    entry. Each entry is routed by substring match on its key:

    * key contains ``'pos_embedding'``: stored under ``"pos_embedding"``.
      With ``merge_later`` the first slot along dim 1 is kept once and the
      remaining slots are tiled ``temporal_size`` times; otherwise the whole
      tensor is tiled ``temporal_size + 1`` times along dim 1.
    * key contains ``'cls'``: with ``merge_later`` stored unchanged under
      ``"cls"``; otherwise tiled to produce ``"cls_s"`` and ``"cls_t"``,
      using ``model.cls_s.shape[0]`` and ``model.cls_t.shape[1]`` as the
      repeat counts.
    * key contains ``'fc'``: skipped.
    * anything else: kept under its original key.

    Only resulting keys that already exist in ``model.state_dict()`` are
    loaded; the rest are silently dropped.

    Args:
        model (CompactVidTr): model to be inflated. Must expose ``cls_s`` and
            ``cls_t`` attributes when ``merge_later`` is false.
        weights_dir (str): path of the checkpoint passed to ``torch.load``.
        temporal_size (int): Number of frames in temporal.
        merge_later (bool): selects the tiling scheme described above.

    Note:
        The tiling calls use three repeat factors, so the affected tensors are
        presumably 3-D with tokens on dim 1 -- TODO confirm against the
        checkpoint format.
    """
    source_weights = torch.load(weights_dir, map_location='cpu')
    target_state = model.state_dict()

    inflated = {}
    for name, tensor in source_weights.items():
        if 'pos_embedding' in name:
            if merge_later:
                # Keep the leading slot once (presumably the class-token
                # position) and tile only the remaining slots over time.
                leading_slot = tensor[:, :1, :]
                tiled_rest = tensor[:, 1:, :].repeat(1, temporal_size, 1)
                inflated["pos_embedding"] = torch.cat(
                    (leading_slot, tiled_rest), dim=1
                )
            else:
                inflated["pos_embedding"] = tensor.repeat(
                    1, temporal_size + 1, 1
                )
            continue

        if 'cls' in name:
            if merge_later:
                inflated["cls"] = tensor
            else:
                # One source token feeds both split tokens; the repeat counts
                # are read from the target model's own parameters.
                inflated["cls_s"] = tensor.repeat(model.cls_s.shape[0], 1, 1)
                inflated["cls_t"] = tensor.repeat(1, model.cls_t.shape[1], 1)
            continue

        if 'fc' in name:
            # Entries whose key mentions 'fc' are never transferred.
            continue

        inflated[name] = tensor

    # Overwrite only the entries the target model actually owns.
    for name, tensor in inflated.items():
        if name in target_state:
            target_state[name] = tensor

    model.load_state_dict(target_state)
    print("Inflate model success.")