def main()

in eval_retrieval_feature_extract.py [0:0]


def _build_model(args):
    """Build the ResNet-50 extractor, wrap it in DDP, load weights, and set eval mode."""
    model_config_dict = dict(
        num_classes=128,
        mlp=True,
    )
    model = resnet50(**model_config_dict).cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    load_pretrained(args, model)
    model.eval()
    return model


def _extract_features(model, data_loader, layer):
    """Run ``model`` over every batch of ``data_loader``; return (features, labels).

    Each batch yields ``(data, cls)``. The model output for a batch is reshaped to
    ``(cls.numel(), -1)``, i.e. one flattened feature row per label, so a batch
    with a single label produces exactly the same row as ``feat.view(-1)``.
    Larger batches are supported as long as the model output is ordered by
    sample along its leading dimension (assumption — TODO confirm for
    ``tsn_mode=True``).

    Returns:
        all_feat: 2-D array, one row per label, stacked in loader order.
        all_feat_cls: 1-D ``np.int32`` array of labels aligned with ``all_feat``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    num_batches = len(data_loader)
    feat_chunks = []
    cls_chunks = []

    with torch.no_grad():
        for idx, (data, cls) in enumerate(data_loader):
            logger.info('%d/%d', idx, num_batches)
            # data: B * S * C * H * W
            feat = model(data.cuda(), layer=layer, tsn_mode=True)
            # One flattened feature row per label; for a single label this is
            # identical to feat.view(-1).
            feat_chunks.append(feat.reshape(cls.numel(), -1).cpu().numpy())
            cls_chunks.append(cls.reshape(-1).cpu().numpy())

    if not feat_chunks:
        # np.stack/np.concatenate would fail with an opaque message here.
        raise ValueError('data loader yielded no batches; no features to save')

    all_feat = np.concatenate(feat_chunks, axis=0)
    # Cast matches the original int32 label buffer.
    all_feat_cls = np.concatenate(cls_chunks).astype(np.int32)
    return all_feat, all_feat_cls


def main(args):
    """Extract features for retrieval evaluation and save them as .npy files.

    Builds the loader via ``get_loader(args)``, builds and loads the model, runs
    it over the loader with ``layer=args.layer`` and ``tsn_mode=True``, then
    writes into ``args.output_dir`` (created if missing):

    * ``feature_{datamode}_{local_rank}.npy`` — stacked feature rows;
    * ``feature_{datamode}_cls_{local_rank}.npy`` — int32 labels aligned with them;
    * on global rank 0 only, ``vid_num_{datamode}.npy`` — ``[total_num]`` as
      returned by ``get_loader`` (presumably the total video count — verify).

    Raises:
        ValueError: if the data loader yields no batches.
    """
    data_loader, total_num = get_loader(args)
    logger.info('using data: {}'.format(len(data_loader)))

    model = _build_model(args)
    logger.info('model init done')

    all_feat, all_feat_cls = _extract_features(model, data_loader, args.layer)

    # Create the directory only after extraction, but before saving, so a
    # missing directory does not waste the whole extraction pass.
    os.makedirs(args.output_dir, exist_ok=True)
    # NOTE(review): filenames use local_rank, so on a multi-node run writing to
    # a shared filesystem, ranks with equal local_rank overwrite each other's
    # files — confirm the launch setup before relying on this.
    np.save(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.datamode, args.local_rank)), all_feat)
    np.save(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.datamode, args.local_rank)), all_feat_cls)

    if dist.get_rank() == 0:
        np.save(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.datamode)), np.array([total_num]))