in src/retrieval_utils.py [0:0]
def get_model(args, get_video_encoder_only=True, logger=None):
    """Build the model, optionally restore a checkpoint, and prepare it for eval.

    Args:
        args: Namespace-like object. Reads ``vid_base_arch``,
            ``aud_base_arch``, ``pretrained``, ``num_clusters``, ``use_mlp``,
            ``headcount`` and ``weights_path``; also reads ``pool_op`` when
            ``get_video_encoder_only`` is true. ``weights_path`` may be
            ``None`` or the literal string ``'None'`` to skip checkpoint
            loading. When a checkpoint is loaded, ``args.ckpt_epoch`` is set
            from the checkpoint's ``'epoch'`` entry as a side effect.
        get_video_encoder_only: If true, keep only ``model.video_network.base``
            and rebuild it as a ``torch.nn.Sequential`` of its stem, its four
            layers, a 2x2x2 3D pooling layer and ``Flatten``.
        logger: Unused by this function; kept for interface compatibility.

    Returns:
        The model in eval mode. When CUDA is available it is moved to the GPU
        and wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: If ``get_video_encoder_only`` is true and ``args.pool_op``
            is neither ``'max'`` nor ``'avg'``.
    """
    # Load model
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=args.num_clusters,
        norm_feat=False,
        use_mlp=args.use_mlp,
        headcount=args.headcount,
    )

    # Load model weights
    start = time.time()
    weights_path = args.weights_path
    # The path may come from the command line as the literal string 'None'.
    has_weights_path = weights_path is not None and weights_path != 'None'
    if has_weights_path:
        print("Loading model weights")
        if os.path.exists(weights_path):
            # The model is still on the CPU at this point, so map the
            # checkpoint there; this also lets GPU-saved checkpoints load on
            # CPU-only machines.
            ckpt_dict = torch.load(weights_path, map_location="cpu")
            model_weights = ckpt_dict["model"]
            args.ckpt_epoch = ckpt_dict['epoch']
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            # Keep the best-effort behaviour (no exception), but do not skip
            # a missing checkpoint silently.
            print(
                f"Warning: weights path does not exist, "
                f"skipping weight loading: {weights_path}",
                flush=True,
            )
    print(f"Time to load model weights: {time.time() - start}")

    # Put model in eval mode
    model.eval()

    # Get video encoder for video-only retrieval
    if get_video_encoder_only:
        model = model.video_network.base
        if args.pool_op == 'max':
            pool = torch.nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
        elif args.pool_op == 'avg':
            pool = torch.nn.AvgPool3d((2, 2, 2), stride=(2, 2, 2))
        else:
            # An assert on a non-empty string never fails; raise explicitly so
            # an unsupported value is reported instead of leaving `pool` unset.
            raise ValueError(
                "Only 'max' and 'avg' pool operations allowed, "
                f"got {args.pool_op!r}"
            )

        # Set up model
        model = torch.nn.Sequential(
            model.stem,
            model.layer1,
            model.layer2,
            model.layer3,
            model.layer4,
            pool,
            Flatten(),
        )

    if torch.cuda.is_available():
        model = model.cuda()
        model = torch.nn.DataParallel(model)
    return model