in d2go/export/torchscript.py [0:0]
def tracing_adapter_wrap_export(old_f):
    """Decorate an export function so tracing goes through TracingAdapter when needed.

    ``old_f`` is called as
    ``old_f(cls, model, input_args, save_path, export_method, **export_kwargs)``
    and must return a mutable mapping of load kwargs.

    The returned wrapper has the same signature and behaves as follows:

    * If ``export_kwargs`` contains a truthy ``force_disable_tracing_adapter``
      (this key is always removed before ``old_f`` is called), or
      ``export_kwargs["jit_mode"]`` (default ``"trace"``) is not ``"trace"``,
      ``old_f`` is called unchanged and its result is returned as is.
    * Otherwise, if both ``input_args`` and the result of ``model(*input_args)``
      pass ``_is_data_flattened_tensors``, ``old_f`` is called unchanged and
      ``"tracing_adapted": False`` is added to its result.
    * Otherwise the model is wrapped in ``TracingAdapter``, ``old_f`` is called
      with the adapter and its flattened inputs, and the result is extended
      with ``"tracing_adapted": True`` plus the dumped ``inputs_schema`` and
      ``outputs_schema``.

    Raises:
        AssertionError: if ``old_f`` already returned one of the keys this
            wrapper is about to add.
    """
    # Function-scope import so the block does not depend on the file's
    # top-level imports.
    import functools

    # Preserve the name/docstring of the wrapped export function.
    @functools.wraps(old_f)
    def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
        # Pop (not get): this flag is consumed here and must not reach old_f.
        force_disable_tracing_adapter = export_kwargs.pop(
            "force_disable_tracing_adapter", False
        )
        # jit_mode is only inspected; it is still forwarded to old_f.
        is_trace_mode = export_kwargs.get("jit_mode", "trace") == "trace"
        if force_disable_tracing_adapter or not is_trace_mode:
            # NOTE(review): this message is also logged when the adapter is
            # force-disabled while in trace mode.
            logger.info("Not trace mode, export normally")
            return old_f(
                cls, model, input_args, save_path, export_method, **export_kwargs
            )

        if _is_data_flattened_tensors(input_args):
            logger.info("Dry run the model to check if TracingAdapter is needed ...")
            # The model is executed once here purely to inspect its output
            # structure.
            outputs = model(*input_args)
            if _is_data_flattened_tensors(outputs):
                logger.info(
                    "Both inputs and outputs are flattened tensors, export the model as is."
                )
                load_kwargs = old_f(
                    cls, model, input_args, save_path, export_method, **export_kwargs
                )
                assert "tracing_adapted" not in load_kwargs
                load_kwargs.update({"tracing_adapted": False})
                return load_kwargs
            logger.info(
                "The outputs are not flattened tensors, can't trace normally."
            )
        else:
            logger.info("The inputs are not flattened tensors, can't trace normally.")

        logger.warning(
            "Wrap the model with TracingAdapter to handle non-flattened inputs/outputs,"
            " please be aware that the exported model will have different input/output data structure."
        )
        adapter = TracingAdapter(model, input_args)
        load_kwargs = old_f(
            cls,
            adapter,
            adapter.flattened_inputs,
            save_path,
            export_method,
            **export_kwargs,
        )
        # Schemas are read only after old_f has run; presumably the adapter's
        # outputs_schema is populated during that call -- keep this order.
        inputs_schema = dump_dataclass(adapter.inputs_schema)
        outputs_schema = dump_dataclass(adapter.outputs_schema)
        # Guard against silently overwriting keys produced by old_f.
        assert "tracing_adapted" not in load_kwargs
        assert "inputs_schema" not in load_kwargs
        assert "outputs_schema" not in load_kwargs
        load_kwargs.update(
            {
                "tracing_adapted": True,
                "inputs_schema": inputs_schema,
                "outputs_schema": outputs_schema,
            }
        )
        return load_kwargs

    return new_f