def main()

in eval_svm_feature_extract.py [0:0]


def main(args):
    """Extract one pooled feature vector per batch and save them to ``args.output_dir``.

    Builds the ``resnet50`` backbone with a fixed config, wraps it in
    ``DistributedDataParallel`` on ``args.local_rank``, loads weights with
    ``load_pretrained`` and runs it in eval mode (under ``torch.no_grad``) over
    every batch from ``get_loader(args)``. For each batch, the outputs of
    ``model(data, layer=6)`` for all clips/frames in the batch are flattened per
    sample and averaged over the sample axis into a single vector.

    Files written by every process:
      * ``feature_{datasplit}_{local_rank}.npy``: stacked features, one row per batch.
      * ``feature_{datasplit}_cls_{local_rank}.npy``: int32 label for each row.
    Global rank 0 additionally writes ``vid_num_{datasplit}.npy`` containing
    ``[total_num]`` as returned by ``get_loader``.

    Args:
        args: parsed options. This function reads ``local_rank``, ``output_dir``
            and ``datasplit`` directly, and passes ``args`` on to ``get_loader``
            and ``load_pretrained``.

    Raises:
        Whatever ``cls.item()`` raises if a batch carries more than one label.
        The loop assumes each batch is exactly one video.
    """
    data_loader, total_num = get_loader(args)
    logger.info('using data: {}'.format(len(data_loader)))
    model_config_dict = dict(
        num_classes=128,
        mlp=True,
        intra_out=True,
        order_out=True,
        tsn_out=True,
    )
    model = resnet50(**model_config_dict).cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    load_pretrained(args, model)
    model.eval()

    logger.info('model init done')
    all_feat = []
    # One label slot per batch; relies on one video per batch (see cls.item()).
    all_feat_cls = np.zeros([len(data_loader)], dtype=np.int32)
    with torch.no_grad():
        for idx, (data, cls) in enumerate(data_loader):
            logger.info('{}/{}'.format(idx, len(data_loader)))
            data_size = data.size()
            if data_size[1] != 3:
                # Fold all leading dims into one stack of 3-channel H x W images.
                data = data.view((-1, 3, data_size[-2], data_size[-1]))
            data = data.cuda()
            feat = model(data, layer=6)
            # Flatten per sample rather than squeeze(): with a single clip
            # (N == 1) squeeze() also removes the sample axis, and the mean
            # below would then average across the feature dimension instead.
            # NOTE(review): assumes the model output is sample-first — confirm.
            feat = feat.reshape(feat.size(0), -1)
            feat_avg = feat.mean(dim=0)
            all_feat.append(feat_avg.cpu().numpy())
            all_feat_cls[idx] = cls.item()
    all_feat = np.stack(all_feat, axis=0)
    # NOTE(review): output files are keyed by args.local_rank, while the guard
    # below uses the global dist.get_rank(). If local_rank is a per-node index
    # (the usual torch.distributed launcher convention), multi-node runs writing
    # to a shared output_dir would overwrite each other — confirm.
    np.save(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.datasplit, args.local_rank)), all_feat)
    np.save(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.datasplit, args.local_rank)), all_feat_cls)

    if dist.get_rank() == 0:
        np.save(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.datasplit)), np.array([total_num]))