in pt/vmz/func/test.py [0:0]
def test_main(args):
    """Run evaluation only: build the validation dataset/loader, create the
    model (pretrained or restored from a checkpoint), and invoke `test`.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options. Fields read here: output_dir, device, scale_h,
        scale_w, crop_size, val_file, valdir, num_frames,
        val_clips_per_video, batch_size, workers, model, pretrained,
        resume_from_model.
    """
    torchvision.set_video_backend("video_reader")
    if args.output_dir:
        utils.mkdir(args.output_dir)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Validation transform: normalize happens before the center crop,
    # matching the original pipeline order.
    transform_test = torchvision.transforms.Compose(
        [
            T.ToTensorVideo(),
            T.Resize((args.scale_h, args.scale_w)),
            T.NormalizeVideo(
                mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)
            ),
            T.CenterCropVideo((args.crop_size, args.crop_size)),
        ]
    )

    print("Loading validation data")
    if os.path.isfile(args.val_file):
        # NOTE(review): `metadata` is loaded but never used in this function —
        # presumably get_dataset re-reads the cached metadata itself; confirm
        # against get_dataset and drop this load if redundant.
        metadata = torch.load(args.val_file)
    root = args.valdir  # NOTE(review): also unused here — verify and remove.

    # TODO: add test option for datasets that support that
    dataset_test = get_dataset(args, transform_test, "val")
    # Recompute clips at a fixed frame rate so every run samples frames the
    # same way regardless of the source videos' native fps.
    dataset_test.video_clips.compute_clips(args.num_frames, 1, frame_rate=15)
    test_sampler = UniformClipSampler(
        dataset_test.video_clips, args.val_clips_per_video
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
    )

    print("Creating model")
    # TODO: model only from our models
    available_models = {**models.__dict__}
    model = available_models[args.model](pretraining=args.pretrained)
    model.to(device)
    model = torch.nn.parallel.DataParallel(model)
    # Keep a handle on the unwrapped module so checkpoints saved without the
    # DataParallel "module." prefix load cleanly. (Fix: removed the pre-wrap
    # `model_without_ddp = model` assignment that was immediately overwritten.)
    model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    # Either the model ships with pretrained weights, or we restore it from a
    # user-supplied checkpoint.
    if not args.pretrained:
        print(f"Loading the model from {args.resume_from_model}")
        checkpoint = torch.load(args.resume_from_model, map_location="cpu")
        if "model" in checkpoint:
            # Full training checkpoint: weights live under the "model" key.
            model_without_ddp.load_state_dict(checkpoint["model"])
        else:
            # Bare state_dict checkpoint.
            model_without_ddp.load_state_dict(checkpoint)

    print("Starting test_only")
    metric_logger = log.MetricLogger(delimiter=" ", writer=None, stat_set="val")
    # NOTE(review): the positional `2` is presumably the print/log frequency —
    # confirm against test()'s signature and pass it by keyword if so.
    test(model, criterion, data_loader_test, device, 2, metric_logger)