def _export_yolox()

in easycv/apis/export.py

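For reference, a minimal config sketch covering the fields this function reads
(field names mirror the cfg/cfg.export lookups in the body below; the values
are illustrative placeholders, not recommended settings):

# Illustrative only; not a config shipped with EasyCV.
img_scale = (640, 640)                 # must be a 2-tuple
export = dict(
    export_type='jit',                 # one of: raw, jit, blade, onnx
    preprocess_jit=True,               # also save a scripted preprocess module
    batch_size=1,                      # static batch size of the example input
    static_opt=True,                   # required if use_trt_efficientnms=True
    use_trt_efficientnms=False,        # fuse NMS via a TensorRT plugin
    blade_config=dict(                 # only read on the blade path
        enable_fp16=True, fp16_fallback_op_ratio=0.3),
)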

def _export_yolox(model, cfg, filename):
    """ export cls (cls & metric learning)model and preprocess config
    Args:
        model (nn.Module):  model to be exported
        cfg: Config object
        filename (str): filename to save exported models
    """

    if hasattr(cfg, 'export'):
        export_type = getattr(cfg.export, 'export_type', 'raw')
        default_export_type_list = ['raw', 'jit', 'blade', 'onnx']
        if export_type not in default_export_type_list:
            logging.warning(
                'YOLOX-PAI only supports export types in [raw, jit, blade, onnx]; falling back to raw as default'
            )
            export_type = 'raw'

        model.export_type = export_type
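        # the model's forward/export logic is assumed to branch on export_type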

        if export_type != 'raw':
            from easycv.utils.misc import reparameterize_models
            # reparameterize models before any non-raw (jit/blade/onnx) export
            model = reparameterize_models(model)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            model = copy.deepcopy(model)
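            # deep copy so the patching and tracing below do not modify the
            # caller's model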

            preprocess_jit = cfg.export.get('preprocess_jit', False)

            batch_size = cfg.export.get('batch_size', 1)
            static_opt = cfg.export.get('static_opt', True)
            use_trt_efficientnms = cfg.export.get('use_trt_efficientnms',
                                                  False)
            # assert image scale and assign input
            img_scale = cfg.get('img_scale', (640, 640))

            assert (
                len(img_scale) == 2
            ), 'img_scale in the export config must be an (int, int) tuple!'

            input = 255 * torch.rand((batch_size, 3) + tuple(img_scale))
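            # `input` is a dummy NCHW example with values in [0, 255), reused
            # for tracing and for the blade/onnx exports below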

            # use_trt_efficientnms is only supported when static_opt=True
            if static_opt is not True:
                assert (
                    not use_trt_efficientnms
                ), 'Export YoloX predictor: use_trt_efficientnms=True requires static_opt=True!'

            # optionally save a scripted preprocess module alongside the export
            save_preprocess_jit = preprocess_jit

            # set model use_trt_efficientnms
            if use_trt_efficientnms:
                from easycv.toolkit.blade import create_tensorrt_efficientnms
                if hasattr(model, 'get_nmsboxes_num'):
                    nmsbox_num = int(model.get_nmsboxes_num(img_scale))
                else:
                    logging.warning(
                        'PAI-YOLOX: use_trt_efficientnms requires the model to implement get_nmsboxes_num; attribute missing, using 8400 (80*80 + 40*40 + 20*20) as default!'
                    )
                    nmsbox_num = 8400

                tmp_example_scores = torch.randn(
                    [batch_size, nmsbox_num, 4 + 1 + len(cfg.CLASSES)],
                    dtype=torch.float32)
                logging.warning(
                    'PAI-YOLOX: use_trt_efficientnms with static shape [{}, {}, {}]'
                    .format(batch_size, nmsbox_num, 4 + 1 + len(cfg.CLASSES)))
                model.trt_efficientnms = create_tensorrt_efficientnms(
                    tmp_example_scores,
                    iou_thres=model.nms_thre,
                    score_thres=model.test_conf)
                model.use_trt_efficientnms = True
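                # both attributes set above are assumed to be consumed by the
                # model's forward so the plugin is baked into the trace below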

            model.eval()
            model.to(device)

            model_export = ModelExportWrapper(
                model,
                input.to(device),
                trace_model=True,
            )

            model_export.eval().to(device)

            # trace model
            yolox_trace = torch.jit.trace(model_export, input.to(device))
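            # yolox_trace feeds the jit and blade branches below; the onnx
            # branch re-exports the eager model instead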

            # save export model
            if export_type == 'blade':
                blade_config = cfg.export.get(
                    'blade_config',
                    dict(enable_fp16=True, fp16_fallback_op_ratio=0.3))

                from easycv.toolkit.blade import blade_env_assert, blade_optimize
                assert blade_env_assert()

                # optimize model with blade
                yolox_blade = blade_optimize(
                    speed_test_model=model,
                    model=yolox_trace,
                    inputs=(input.to(device), ),
                    blade_config=blade_config,
                    static_opt=static_opt)

                with io.open(filename + '.blade', 'wb') as ofile:
                    torch.jit.save(yolox_blade, ofile)
                with io.open(filename + '.blade.config.json', 'w') as ofile:
                    config = dict(
                        model=cfg.model,
                        export=cfg.export,
                        test_pipeline=cfg.test_pipeline,
                        classes=cfg.CLASSES)

                    json.dump(config, ofile)

            if export_type == 'onnx':

                with io.open(
                        filename + '.config.json' if filename.endswith('onnx')
                        else filename + '.onnx.config.json', 'w') as ofile:
                    config = dict(
                        model=cfg.model,
                        export=cfg.export,
                        test_pipeline=cfg.test_pipeline,
                        classes=cfg.CLASSES)

                    json.dump(config, ofile)

                torch.onnx.export(
                    model,
                    input.to(device),
                    filename if filename.endswith('onnx') else filename +
                    '.onnx',
                    export_params=True,
                    opset_version=12,
                    do_constant_folding=True,
                    input_names=['input'],
                    output_names=['output'],
                )

            if export_type == 'jit':
                with io.open(filename + '.jit', 'wb') as ofile:
                    torch.jit.save(yolox_trace, ofile)

                with io.open(filename + '.jit.config.json', 'w') as ofile:
                    config = dict(
                        model=cfg.model,
                        export=cfg.export,
                        test_pipeline=cfg.test_pipeline,
                        classes=cfg.CLASSES)

                    json.dump(config, ofile)

            # save export preprocess/postprocess
            if save_preprocess_jit:
                tpre_input = 255 * torch.rand((batch_size, ) +
                                              tuple(img_scale) + (3, ))
                tpre = ProcessExportWrapper(
                    example_inputs=tpre_input.to(device),
                    process_fn=PreProcess(
                        target_size=img_scale, keep_ratio=True))
                tpre.eval().to(device)

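                # scripted rather than traced, presumably so the preprocess
                # graph can handle variable input shapes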
                preprocess = torch.jit.script(tpre)
                with io.open(filename + '.preprocess', 'wb') as prefile:
                    torch.jit.save(preprocess, prefile)

        else:
            if hasattr(cfg, 'test_pipeline'):
                # with last pipeline Collect
                test_pipeline = cfg.test_pipeline
                print(test_pipeline)
            else:
                raise ValueError('export model config without test_pipeline')

            config = dict(
                model=cfg.model,
                test_pipeline=test_pipeline,
                CLASSES=cfg.CLASSES,
            )

            meta = dict(config=json.dumps(config))
            checkpoint = dict(
                state_dict=model.state_dict(), meta=meta, author='EasyCV')
            with io.open(filename, 'wb') as ofile:
                torch.save(checkpoint, ofile)
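
A minimal consumer sketch (not part of export.py) for loading the artifacts
written above; the filename is hypothetical, and each block assumes the
corresponding export_type was used:

import json

import onnxruntime as ort
import torch

filename = 'yolox_export.pt'  # hypothetical path passed to _export_yolox

# raw: a plain checkpoint with the config embedded as a JSON string in meta
ckpt = torch.load(filename, map_location='cpu')
raw_cfg = json.loads(ckpt['meta']['config'])

# jit: a TorchScript module plus a sidecar *.jit.config.json
jit_model = torch.jit.load(filename + '.jit', map_location='cpu')
with open(filename + '.jit.config.json') as f:
    jit_cfg = json.load(f)

dummy = 255 * torch.rand(1, 3, 640, 640)  # matches the traced example input
with torch.no_grad():
    det_out = jit_model(dummy)

# onnx: the graph input was named 'input' in torch.onnx.export above
sess = ort.InferenceSession(filename + '.onnx')
onnx_out = sess.run(None, {'input': dummy.numpy()})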