def get_model()

in src/retrieval_utils.py [0:0]


def get_model(args, get_video_encoder_only=True, logger=None):
    """Build the model, optionally load checkpoint weights, and prepare it for retrieval.

    Args:
        args: Configuration object. The attributes read here are
            ``vid_base_arch``, ``aud_base_arch``, ``pretrained``,
            ``num_clusters``, ``use_mlp``, ``headcount``, ``weights_path``
            and, when ``get_video_encoder_only`` is true, ``pool_op``.
            As a side effect, ``args.ckpt_epoch`` is set from the checkpoint
            when weights are loaded.
        get_video_encoder_only: If true, return only the video backbone
            (stem + layer1..layer4) followed by a 2x2x2 pooling layer and
            ``Flatten``; otherwise return the full model.
        logger: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The model in eval mode. When CUDA is available it is moved to the GPU
        and wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: If ``get_video_encoder_only`` is true and ``args.pool_op``
            is neither ``'max'`` nor ``'avg'``.
    """
    # Load model
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=args.num_clusters,
        norm_feat=False,
        use_mlp=args.use_mlp,
        headcount=args.headcount,
    )

    # Load model weights
    start = time.time()
    weights_path = args.weights_path
    # The path may be a real None or the literal string 'None'; both mean
    # "no checkpoint".
    if isinstance(weights_path, str):
        has_weights_path = weights_path != 'None'
    else:
        has_weights_path = weights_path is not None
    if has_weights_path:
        print("Loading model weights")
        if os.path.exists(weights_path):
            # The model is still on CPU at this point, so map the checkpoint
            # to CPU as well; this also lets GPU-saved checkpoints load on
            # hosts without CUDA.
            ckpt_dict = torch.load(weights_path, map_location="cpu")
            model_weights = ckpt_dict["model"]
            args.ckpt_epoch = ckpt_dict['epoch']
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            # Best-effort: keep going with the freshly built model, but make
            # the skipped load visible instead of silent.
            print(
                f"Weights path not found, skipping weight loading: {weights_path}",
                flush=True,
            )
    print(f"Time to load model weights: {time.time() - start}")

    # Put model in eval mode
    model.eval()

    # Get video encoder for video-only retrieval
    if get_video_encoder_only:
        if args.pool_op == 'max':
            pool = torch.nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
        elif args.pool_op == 'avg':
            pool = torch.nn.AvgPool3d((2, 2, 2), stride=(2, 2, 2))
        else:
            # A bare assert on a non-empty string never fires; raise so an
            # unsupported value is reported here rather than as an unbound
            # `pool` further down.
            raise ValueError("Only 'max' and 'avg' pool operations allowed")

        # Set up model
        base = model.video_network.base
        model = torch.nn.Sequential(
            base.stem,
            base.layer1,
            base.layer2,
            base.layer3,
            base.layer4,
            pool,
            Flatten(),
        )

    if torch.cuda.is_available():
        model = model.cuda()
        model = torch.nn.DataParallel(model)
    return model