in python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py [0:0]
def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
"""Create adaptive quantizer
Parameters
----------
base_cls: BaseQuantizer
The base quantizer class
Returns
-------
quantizer_cls: BaseQuantizer
The quantizer class.
"""
@msc_utils.register_tool
class Quantizer(base_cls):
"""Adaptive quantizer for tensorrt"""
def setup(self) -> dict:
"""Setup the tool
Returns
-------
info: dict
The setup info.
"""
if self._plan:
self._use_range = all(
info.get("use_range", False) for info in self._plan.values()
)
else:
self._use_range = True
return super().setup()
def _reset(
self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
"""Reset the tool
Parameters
----------
graphs: list<MSCgraph>
The msc graphs.
weights: list<dict<str, tvm.nd.array>>
The weights
Returns
-------
graphs: list<MSCgraph>
The msc graphs.
weights: list<dict<str, tvm.nd.array>>
The weights
"""
config_folder = msc_utils.get_config_dir()
self._range_files = [config_folder.relpath(g.name + ".range") for g in graphs]
calibrate_root = msc_utils.get_dataset_dir().create_dir("Calibrate")
self._calibrate_folders = [calibrate_root.relpath(g.name) for g in graphs]
if self._calibrated:
if self._use_range:
for r_file, graph in zip(self._range_files, graphs):
if not os.path.isfile(r_file):
self._plan_to_range(graph, r_file)
self._logger.debug(
"G[%s](%s) use range file: %s",
graph.name,
self._stage,
r_file,
)
else:
self._quantized_tensors = set()
elif self._stage == QuantizeStage.GATHER:
self._calibrate_savers = []
for folder, graph in zip(self._calibrate_folders, graphs):
saver_options = {"input_names": [i.name for i in graph.get_inputs()]}
saver = msc_utils.IODataSaver(folder, saver_options)
self._calibrate_savers.append(saver)
self._logger.debug(
"G[%s](%s) create calibrate saver: %s",
graph.name,
self._stage,
saver,
)
else:
assert all(
msc_utils.is_io_dataset(f) for f in self._calibrate_folders
), "Some IODataset missing: " + str(self._calibrate_folders)
return super()._reset(graphs, weights)
def _execute_after_build(self, codegen_context: dict) -> dict:
"""Execute after model build
Parameters
----------
codegen_context: dict
The context.
Returns
----------
codegen_context: dict
The processed context.
"""
if self._stage == QuantizeStage.GATHER and self._forward_cnt == 0:
return codegen_context
if not self._use_range:
return codegen_context
processed = ["// Set int8 calibrator"]
range_file = self.get_graph().name + ".range"
version = [int(v) for v in codegen_context["version"].split(".")]
if msc_utils.compare_version(version, [6, 0, 0]) >= 0:
configer = codegen_context["config"]
else:
configer = codegen_context["builder"]
# check the range file if calibrated
if self._calibrated:
processed.extend(
[
'if (!FileUtils::FileExist("{}")) {{'.format(range_file),
' logger.log(ILogger::Severity::kERROR, "{} not exist!");'.format(
range_file
),
" return -1;",
"}",
]
)
processed.extend(
[
'MSCInt8EntropyCalibrator2 calibrator("{}", "{}");'.format(
range_file, self._calibrate_folders[self._graph_id]
),
"{}->setInt8Calibrator(&calibrator);".format(configer),
]
)
codegen_context["processed"].extend(processed)
return codegen_context
def _execute_before_forward(self, step_context: dict) -> dict:
"""Execute before model forward
Parameters
----------
step_context: dict
The context.
Returns
----------
step_context: dict
The processed context.
"""
if self._stage == QuantizeStage.GATHER:
saver = self._calibrate_savers[self._graph_id]
saver.save_batch(
{name: data.numpy() for name, data in step_context["datas"].items()}
)
for name, data in step_context["datas"].items():
self.debug_tensors(name, "any", "ctx_gather", {"gather": data})
super()._execute_before_forward(step_context)
def _quantize_tensor(
self,
tensor_ctx: Dict[str, str],
name: str,
consumer: str,
strategys: List[ToolStrategy],
) -> Dict[str, str]:
"""Quantize tensor
Parameters
-------
tensor_ctx: dict<str, str>
Tensor describe items.
name: str
The name of the tensor.
consumer: str
The name of the consumer.
strategys: list<ToolStrategy>
The strategys for the tensor.
Returns
-------
tensor_ctx: dict<str, str>
Tensor items with processed.
"""
if not self._use_range and name not in self._quantized_tensors:
self._quantized_tensors.add(name)
return super()._quantize_tensor(tensor_ctx, name, consumer, strategys)
return tensor_ctx
def calibrate(self) -> dict:
"""Calibrate the datas
Returns
-------
plan: dict
The calibrated plan.
"""
for r_file, graph in zip(self._range_files, self._graphs):
self._range_to_plan(graph, r_file)
self._calibrated, self._forward_cnt = True, 0
self.change_stage("quantize")
return self._plan
def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]:
"""Update the generate configs
Parameters
----------
generate_config: dict<str, Any>
The generate_config.
Returns
-------
generate_config: dict<str, Any>
The updated generate_config.
"""
if self._calibrated:
if self._use_range:
for config, r_file in zip(generate_config["codegen"], self._range_files):
if os.path.isfile(r_file):
config.update({"range_file": r_file, "precision": "int8"})
elif self._stage == QuantizeStage.GATHER and self._forward_cnt > 0:
for config, saver, r_file in zip(
generate_config["codegen"], self._calibrate_savers, self._range_files
):
saver.finalize()
msg = "Save {} batch to {}".format(self._forward_cnt, saver.folder)
self._logger.debug(self.msg_mark(msg, in_forward=False))
config.update(
{"dataset": saver.folder, "range_file": r_file, "precision": "int8"}
)
self.change_stage(QuantizeStage.CALIBRATE)
return generate_config
def _plan_to_range(self, graph: MSCGraph, range_file: str, title="MSCCalibrate"):
"""Extract plan config to range_file
Parameters
----------
plan: dict
The plan.
graph: MSCGraph
The graph.
range_file: str
The output range_file path.
title: str
The title of the range file.
"""
def _scale_to_hex(scale):
return hex(struct.unpack("<I", struct.pack("<f", scale / 127))[0])[2:]
recorded = set()
with open(range_file, "w") as f:
f.write(title + "\n")
for name, info in self._plan.items():
t_name, _ = self.from_tensor_id(name)
if not graph.find_tensor(t_name):
continue
if t_name not in recorded:
f.write("{}: {}\n").format(t_name, _scale_to_hex(info["scale"]))
recorded.add(t_name)
self._logger.debug(
"Graph[%s](%s) extract %d plan to range %s",
graph.name,
self._stage,
len(recorded),
range_file,
)
def _range_to_plan(self, graph: MSCGraph, range_file: str):
"""Extract scale in range_file to plan
Parameters
----------
graph: MSCGraph
The graph.
range_file: str
The input range_file path.
"""
range_num = 0
with open(range_file, "r") as f:
f.readline()
line = f.readline()
while line:
name, scale = line.split(": ")
scale = scale.strip()
if scale == "0":
value = 0.0
else:
value = struct.unpack("!f", bytes.fromhex(scale))[0] * 127
range_num += 1
consumers = graph.find_consumers(name)
if consumers:
for c in consumers:
self._plan[self.to_tensor_id(name, c.name)] = {
"scale": value,
"use_range": True,
}
else:
self._plan[self.to_tensor_id(name, "exit")] = {
"scale": value,
"use_range": True,
}
line = f.readline()
self._logger.debug(
"Graph[%s](%s) extract %d range to plan from %s",
graph.name,
self._stage,
range_num,
range_file,
)
@classmethod
def framework(cls):
return MSCFramework.TENSORRT
return Quantizer