def main()

in benchmarks/tools/extract.py [0:0]


def main():
    """Extract features with a checkpointed model and save them per split.

    Steps, in order:

    1. Parse CLI args and load the config; optionally enable OSS access and
       cuDNN benchmark mode; let ``--work_dir`` override ``cfg.work_dir``.
    2. Validate that a checkpoint file was given and exists.
    3. Initialise the distributed environment (unless the launcher is
       ``'none'``), create the work dir and the root logger.
    4. Build the ``cfg.data.extract`` dataset and a non-shuffling loader.
    5. Build the model from ``cfg.model`` or, when the config has no model,
       from the config stored in the checkpoint's ``meta`` entry; load the
       checkpoint and wrap the model for (distributed) data parallelism.
    6. Run ``ExtractProcess`` and, on rank 0 only, cut every output at
       ``cfg.split_at`` and write ``<work_dir>/<split_name>_<key>.npy``.

    Raises:
        AssertionError: if ``--pretrained`` and ``--checkpoint`` are both
            given, if the checkpoint is missing or does not exist, or if
            neither ``cfg.model`` nor checkpoint meta is available.
        ValueError: if the checkpoint meta config carries no ``model`` entry.
    """
    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    if cfg.get('oss_io_config', None):
        io.access_oss(**cfg.oss_io_config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # checkpoint and pretrained are exclusive
    assert args.pretrained == 'random' or args.checkpoint is None, \
        'Checkpoint and pretrained are exclusive.'
    # A checkpoint is mandatory for extraction. Validate it before the
    # distributed init and dataset construction so a bad invocation fails
    # fast. The explicit None test keeps a missing --checkpoint from reaching
    # os.path.exists, which rejects None with a TypeError.
    assert args.checkpoint is not None and os.path.exists(args.checkpoint), \
        'checkpoint must be set when use extractor!'
    # NOTE: the two checks together imply args.pretrained == 'random', so
    # there is no pretrained-weights path to configure in this tool.

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'extract_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    dataset = build_dataset(cfg.data.extract)
    seed = 0
    set_random_seed(seed)
    data_loader = build_dataloader(
        dataset,
        cfg.data.imgs_per_gpu,
        cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
        replace=getattr(cfg.data, 'sampling_replace', False),
        seed=seed,
        drop_last=getattr(cfg.data, 'drop_last', False))

    # build the model, preferring the definition from the config file
    cfg_model = cfg.get('model', None)
    if cfg_model is not None:
        logger.info('load model scripts from cfg config')
        model = build_model(cfg_model)
    else:
        # The checkpoint is only read here, where its meta is actually
        # needed. map_location='cpu' keeps this read from placing tensors
        # on a GPU.
        ckpt_meta = torch.load(
            args.checkpoint, map_location='cpu').get('meta', None)
        assert ckpt_meta is not None, \
            'extract need either cfg model or ckpt with meta!'
        logger.info('load model scripts from ckpt meta')
        ckpt_cfg = json.loads(ckpt_meta['config'])
        if 'model' not in ckpt_cfg:
            raise ValueError(
                'build model from %s, must use model after export' %
                (args.checkpoint))
        model = build_model(ckpt_cfg['model'])

    # load checkpoint weights into the model
    logger.info('Use checkpoint: {} to extract features'.format(
        args.checkpoint))
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor
    extractor = ExtractProcess(extract_list=args.extract_list)

    # run
    outputs = extractor.extract(model, data_loader, distributed=distributed)

    # only rank 0 writes results; nothing to do when there are no outputs
    rank, _ = get_dist_info()
    if rank != 0 or not outputs:
        return

    # cfg.work_dir already contains the --work_dir override (if any), while
    # args.work_dir is None whenever the flag is omitted.
    work_dir = cfg.work_dir
    mmcv.mkdir_or_exist(work_dir)

    split_names = cfg.split_name
    split_at = cfg.split_at
    split_num = len(split_names)
    logger.info('split_num: {}, split_at: {}'.format(split_num, split_at))

    # The slice for each split is the same for every output key, so compute
    # it once: first split -> [:split_at[0]], last split -> [split_at[-1]:],
    # middle splits -> [split_at[i - 1]:split_at[i]].
    split_slices = []
    for ss in range(split_num):
        if ss == 0:
            # A single split without any split point keeps the whole array.
            if split_num == 1 and len(split_at) == 0:
                split_slices.append(slice(None))
            else:
                split_slices.append(slice(None, split_at[0]))
        elif ss == split_num - 1:
            split_slices.append(slice(split_at[-1], None))
        else:
            split_slices.append(slice(split_at[ss - 1], split_at[ss]))

    for key, val in outputs.items():
        for split_name, split_slice in zip(split_names, split_slices):
            output_file = '{}/{}_{}.npy'.format(work_dir, split_name, key)
            np.save(output_file, val[split_slice])