def tracing_adapter_wrap_export()

in d2go/export/torchscript.py [0:0]


def tracing_adapter_wrap_export(old_f):
    """Decorate an export implementation so models with structured I/O can be traced.

    The decorated function must have the signature
    ``old_f(cls, model, input_args, save_path, export_method, **export_kwargs)``
    and return a dict of load kwargs.

    Behaviour of the returned wrapper:

    * If ``export_kwargs["jit_mode"]`` (default ``"trace"``) is not ``"trace"``,
      or the caller passes ``force_disable_tracing_adapter=True``, ``old_f`` is
      called directly and its result is returned unchanged. The
      ``force_disable_tracing_adapter`` flag is consumed here and never
      forwarded to ``old_f``.
    * If ``input_args`` are flattened tensors (per ``_is_data_flattened_tensors``),
      the model is dry-run once. If its outputs are flattened tensors too, the
      model is exported as is and ``"tracing_adapted": False`` is added to the
      returned dict.
    * Otherwise the model is wrapped in ``TracingAdapter`` and the adapter is
      exported with ``adapter.flattened_inputs``. The returned dict gets
      ``"tracing_adapted": True`` plus the dumped ``"inputs_schema"`` and
      ``"outputs_schema"``, which loaders need to rebuild the original structure.

    Args:
        old_f: the export function to wrap.

    Returns:
        The wrapping function. It keeps ``old_f``'s name, docstring and
        ``__wrapped__`` via ``functools.wraps``.

    Raises:
        AssertionError: if ``old_f`` already returns any of the keys this
            wrapper adds.
    """
    import functools

    @functools.wraps(old_f)
    def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
        # Pop the flag so it is never forwarded to ``old_f``.
        force_disable_tracing_adapter = export_kwargs.pop(
            "force_disable_tracing_adapter", False
        )
        is_trace_mode = export_kwargs.get("jit_mode", "trace") == "trace"
        if force_disable_tracing_adapter or not is_trace_mode:
            # Report the actual reason for skipping the adapter; both cases
            # used to be logged as "Not trace mode".
            if force_disable_tracing_adapter:
                logger.info("TracingAdapter is force-disabled, export normally")
            else:
                logger.info("Not trace mode, export normally")
            return old_f(
                cls, model, input_args, save_path, export_method, **export_kwargs
            )

        if _is_data_flattened_tensors(input_args):
            # The inputs are traceable. Run the model once to find out whether
            # the outputs are as well.
            logger.info("Dry run the model to check if TracingAdapter is needed ...")
            outputs = model(*input_args)
            if _is_data_flattened_tensors(outputs):
                logger.info(
                    "Both inputs and outputs are flattened tensors, export the model as is."
                )
                load_kwargs = old_f(
                    cls, model, input_args, save_path, export_method, **export_kwargs
                )
                # Internal invariant: old_f must not produce keys owned by this wrapper.
                assert "tracing_adapted" not in load_kwargs
                load_kwargs["tracing_adapted"] = False
                return load_kwargs
            else:
                logger.info(
                    "The outputs are not flattened tensors, can't trace normally."
                )
        else:
            logger.info("The inputs are not flattened tensors, can't trace normally.")

        logger.warning(
            "Wrap the model with TracingAdapter to handle non-flattened inputs/outputs,"
            " please be aware that the exported model will have different input/output data structure."
        )
        adapter = TracingAdapter(model, input_args)
        load_kwargs = old_f(
            cls,
            adapter,
            adapter.flattened_inputs,
            save_path,
            export_method,
            **export_kwargs,
        )
        # Dump the schemas only after exporting. The adapter's output schema is
        # presumably filled in when the adapter runs during export; TODO confirm
        # against TracingAdapter.
        inputs_schema = dump_dataclass(adapter.inputs_schema)
        outputs_schema = dump_dataclass(adapter.outputs_schema)
        # Internal invariants: old_f must not produce keys owned by this wrapper.
        assert "tracing_adapted" not in load_kwargs
        assert "inputs_schema" not in load_kwargs
        assert "outputs_schema" not in load_kwargs
        load_kwargs.update(
            {
                "tracing_adapted": True,
                "inputs_schema": inputs_schema,
                "outputs_schema": outputs_schema,
            }
        )
        return load_kwargs

    return new_f