in video_generation.py [0:0]
def __load_model(self):
    """Build the backbone selected by ``self.args`` and load its weights.

    The model is created from ``vits.__dict__[self.args.arch]`` with
    ``num_classes=0``, has gradients disabled on every parameter, is put in
    eval mode and moved to ``DEVICE``.  Weights are then taken from the first
    source that applies:

    1. The checkpoint file at ``self.args.pretrained_weights`` if that path
       is an existing file.  If ``self.args.checkpoint_key`` is set and
       present in the checkpoint, that sub-dict is used.  Leading
       ``module.`` / ``backbone.`` wrapper prefixes are stripped from the
       keys and the result is loaded with ``strict=False``.
    2. The reference DINO weights downloaded from
       ``https://dl.fbaipublicfiles.com/dino/`` for the known
       ``(arch, patch_size)`` combinations, loaded with ``strict=True``.
    3. Otherwise the model keeps its random initialization.

    Returns:
        The constructed model (eval mode, parameters frozen).
    """
    # Prefixes added to parameter names by wrappers around the backbone
    # (`backbone.` is induced by the multicrop wrapper).
    wrapper_prefixes = ("module.", "backbone.")

    def strip_wrapper_prefixes(key):
        # Only strip *leading* prefixes: a plain str.replace would also
        # mangle keys that merely contain the text (e.g. "submodule.x").
        # Loop so that any nesting order of the wrappers is handled.
        stripped = True
        while stripped:
            stripped = False
            for prefix in wrapper_prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        return key

    # build model
    model = vits.__dict__[self.args.arch](
        patch_size=self.args.patch_size, num_classes=0
    )
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(DEVICE)

    if os.path.isfile(self.args.pretrained_weights):
        # NOTE(security): torch.load unpickles the file; only point
        # --pretrained_weights at checkpoints from a trusted source.
        state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
        if (
            self.args.checkpoint_key is not None
            and self.args.checkpoint_key in state_dict
        ):
            print(
                f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
            )
            state_dict = state_dict[self.args.checkpoint_key]
        state_dict = {
            strip_wrapper_prefixes(k): v for k, v in state_dict.items()
        }
        # Non-strict: missing/unexpected keys are reported in `msg`, not raised.
        msg = model.load_state_dict(state_dict, strict=False)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                self.args.pretrained_weights, msg
            )
        )
        return model

    print(
        "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
    )
    # Reference checkpoints, keyed by (arch, patch_size).
    reference_urls = {
        ("vit_small", 16): "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        # model used for visualizations in our paper
        ("vit_small", 8): "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth",
        ("vit_base", 16): "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
        ("vit_base", 8): "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
    }
    url = reference_urls.get((self.args.arch, self.args.patch_size))
    if url is not None:
        print(
            "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
        )
        # Deserialize on CPU, like the local-file path above, so loading does
        # not depend on the device the checkpoint was saved from;
        # load_state_dict copies the values into the model's own parameters.
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url,
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    else:
        print(
            "There is no reference weights available for this model => We use random weights."
        )
    return model