def test_main()

in pt/vmz/func/test.py [0:0]


def test_main(args):
    """Evaluate a video classification model on the validation split.

    Builds the validation dataset and a clip-sampling DataLoader, constructs
    the model named by ``args.model`` (wrapped in ``DataParallel``),
    optionally loads its weights from a checkpoint, and runs ``test`` on it.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``output_dir``, ``device``, ``scale_h``, ``scale_w``,
            ``crop_size``, ``num_frames``, ``val_clips_per_video``,
            ``batch_size``, ``workers``, ``model``, ``pretrained`` and, when
            ``pretrained`` is falsy, ``resume_from_model``. ``get_dataset``
            may read further attributes.

    Raises:
        KeyError: if ``args.model`` does not name an attribute of ``models``.
    """
    torchvision.set_video_backend("video_reader")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    # Let cuDNN autotune convolution algorithms; the spatial input size is
    # fixed by the center crop below, so tuning results are reusable.
    torch.backends.cudnn.benchmark = True

    transform_test = torchvision.transforms.Compose(
        [
            T.ToTensorVideo(),
            T.Resize((args.scale_h, args.scale_w)),
            # Fixed per-channel mean/std; presumably the dataset statistics
            # the pretrained models were trained with — TODO confirm.
            T.NormalizeVideo(
                mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)
            ),
            T.CenterCropVideo((args.crop_size, args.crop_size)),
        ]
    )

    print("Loading validation data")
    # TODO: add a test option for datasets that support one
    dataset_test = get_dataset(args, transform_test, "val")

    # Re-index clips of args.num_frames frames with step 1 at a hard-coded
    # 15 fps (assuming video_clips is torchvision's VideoClips — confirm).
    dataset_test.video_clips.compute_clips(args.num_frames, 1, frame_rate=15)

    # Draw args.val_clips_per_video clips per video for evaluation.
    test_sampler = UniformClipSampler(
        dataset_test.video_clips, args.val_clips_per_video
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
    )

    print("Creating model")
    # TODO: model only from our models
    try:
        model_fn = models.__dict__[args.model]
    except KeyError:
        raise KeyError(f"Unknown model {args.model!r}") from None
    model = model_fn(pretraining=args.pretrained)
    model.to(device)

    model = torch.nn.parallel.DataParallel(model)
    # Weights are loaded into the unwrapped module so checkpoint keys do not
    # need DataParallel's "module." prefix (presumably how they were saved).
    model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    # Either the model was built with pretrained weights, or they are loaded
    # from args.resume_from_model.
    if not args.pretrained:
        print(f"Loading the model from {args.resume_from_model}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.resume_from_model, map_location="cpu")
        # Accept both training checkpoints ({"model": state_dict, ...}) and
        # bare state dicts.
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        model_without_ddp.load_state_dict(state_dict)

    print("Starting test_only")
    metric_logger = log.MetricLogger(delimiter="  ", writer=None, stat_set="val")
    # NOTE(review): the positional 2 is presumably a print frequency — confirm
    # against test()'s signature.
    test(model, criterion, data_loader_test, device, 2, metric_logger)