in tinynn/util/quantization_analysis_util.py [0:0]
def _prepare_dummy_inputs(dummy_input, device):
    """Turn ``dummy_input`` into a list of positional model arguments placed on ``device``.

    Args:
        dummy_input: A single tensor, or a tuple/list of positional arguments.
        device: Target device; tensor arguments living on another device are moved there.
            Non-tensor arguments are passed through unchanged.

    Returns:
        list: The positional arguments to call the model with.

    Raises:
        TypeError: If ``dummy_input`` is neither a tensor nor a tuple/list.
    """
    if isinstance(dummy_input, torch.Tensor):
        inputs = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        inputs = list(dummy_input)
    else:
        msg = f'Unsupported type {type(dummy_input)} for dummy input'
        log.error(msg)
        # A real exception (not ``assert``) so the check survives ``python -O``.
        raise TypeError(msg)
    return [
        arg.to(device) if isinstance(arg, torch.Tensor) and arg.device != device else arg
        for arg in inputs
    ]


def _clone_flag_buffers(model, attr):
    """Clone the ``attr`` tensor (e.g. ``'fake_quant_enabled'``) of every submodule that has one.

    Every module carrying the flag is snapshotted, not only ``FakeQuantize`` instances.
    That way, whatever ``model.apply(torch.quantization.disable_*)`` flips can be put back.

    Returns:
        dict: ``{module: cloned_flag_tensor}``.
    """
    return {
        m: getattr(m, attr).clone()
        for m in model.modules()
        if isinstance(getattr(m, attr, None), torch.Tensor)
    }


def _restore_flag_buffers(states, attr):
    """Copy the tensors saved by ``_clone_flag_buffers`` back into each module's ``attr`` in place."""
    for m, saved in states.items():
        getattr(m, attr).copy_(saved)


def _capture_fake_quant_inputs(model, names_by_module, inputs):
    """Run ``model(*inputs)`` once and record the positional inputs of each hooked module.

    Args:
        model: The module to run.
        names_by_module: Maps each module to hook onto its qualified name.
        inputs: Positional arguments for the forward call.

    Returns:
        dict: ``{name: input_tuple}`` for every hooked module that ran, keyed in the order the
        modules first ran. A module that runs more than once keeps its last input.
    """
    captured = {}

    def hook(module, args, output):
        captured[names_by_module[module]] = args

    handles = [m.register_forward_hook(hook) for m in names_by_module]
    try:
        model(*inputs)
    finally:
        # Always detach, so a failing forward pass does not leave hooks on the model.
        for handle in handles:
            handle.remove()
    return captured


def graph_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine'):
    """Generate the cumulative quantization error report using the given metric.

    The model is run twice on ``dummy_input``:

    1. With fake quantization and observers disabled, giving the floating-point reference.
    2. With the fake-quantize flags restored to their previous values, but observers still
       disabled.

    The input of every ``torch.quantization.FakeQuantize`` module is compared between the
    two runs with ``metric`` and reported via ``error_print``. Weight fake-quantizers (names
    ending in ``.weight_fake_quant``) are excluded from the report. The fake-quant and
    observer flags and the training mode are restored on exit, including when an exception
    is raised.

    Args:
        q_model: The quant prepared (QAT-prepared) model. ``DataParallel`` and
            ``DistributedDataParallel`` wrappers are unwrapped.
        dummy_input: A viable input to the model: a tensor, or a tuple/list of positional
            arguments.
        metric: Metric for measuring the error between the floating point tensor and the
            quantized tensor. Defaults to 'cosine'; 'sqnr' is also available.

    Raises:
        KeyError: If ``metric`` is not a key of ``METRIC_DICT``.
        TypeError: If ``dummy_input`` is neither a tensor nor a tuple/list.
    """
    if isinstance(q_model, (DataParallel, DistributedDataParallel)):
        model = q_model.module
    else:
        model = q_model
    metric_fn = METRIC_DICT[metric]

    # Validate the inputs before touching any model state, so a bad dummy_input
    # cannot leave the model half-modified.
    actual_input = _prepare_dummy_inputs(dummy_input, get_module_device(model))

    modules_by_name = {
        n: m for n, m in model.named_modules() if isinstance(m, torch.quantization.FakeQuantize)
    }
    names_by_module = {m: n for n, m in modules_by_name.items()}
    if not modules_by_name:
        log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

    fake_quant_states = _clone_flag_buffers(model, 'fake_quant_enabled')
    observer_states = _clone_flag_buffers(model, 'observer_enabled')
    train_flag = model.training

    model.eval()
    try:
        with torch.no_grad():
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)
            # Float reference pass: fake quantization disabled everywhere.
            float_results = _capture_fake_quant_inputs(model, names_by_module, actual_input)

            # Quantized pass: restore the fake-quantize flags, but keep observers disabled.
            _restore_flag_buffers(fake_quant_states, 'fake_quant_enabled')
            quant_results = _capture_fake_quant_inputs(model, names_by_module, actual_input)

            q_errors_activ = []
            for name, f_args in float_results.items():
                assert name in quant_results, f'{name} not in results'
                if name.endswith('.weight_fake_quant'):
                    continue
                # Report against the module that owns this fake-quantizer.
                parent_name = name.rpartition('.')[0]
                loss = metric_fn(f_args[0], quant_results[name][0])
                q_errors_activ.append((parent_name, modules_by_name[name], loss))
            error_print(metric, q_errors_activ, [], '')
    finally:
        _restore_flag_buffers(fake_quant_states, 'fake_quant_enabled')
        _restore_flag_buffers(observer_states, 'observer_enabled')
        if train_flag:
            model.train()