tinynn/graph/quantization/observer.py (107 lines of code) (raw):

import copy
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch.quantization as torch_q


class MinMaxObserver(torch_q.MinMaxObserver):
    """``torch.quantization.MinMaxObserver`` with the integer range set to [-127, 127]."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Overwrite whatever range the base constructor chose with a symmetric one.
        self.quant_min = -127
        self.quant_max = 127


class PerChannelMinMaxObserver(torch_q.PerChannelMinMaxObserver):
    """``torch.quantization.PerChannelMinMaxObserver`` with the integer range set to [-127, 127]."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Overwrite whatever range the base constructor chose with a symmetric one.
        self.quant_min = -127
        self.quant_max = 127


class MovingAverageMinMaxObserver(torch_q.MovingAverageMinMaxObserver):
    """``torch.quantization.MovingAverageMinMaxObserver`` with the integer range set to [-127, 127]."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Overwrite whatever range the base constructor chose with a symmetric one.
        self.quant_min = -127
        self.quant_max = 127


class MovingAveragePerChannelMinMaxObserver(torch_q.MovingAveragePerChannelMinMaxObserver):
    """``torch.quantization.MovingAveragePerChannelMinMaxObserver`` with the integer range set to [-127, 127]."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Overwrite whatever range the base constructor chose with a symmetric one.
        self.quant_min = -127
        self.quant_max = 127


class HistogramObserverKL(torch_q.HistogramObserver):
    """Histogram observer whose clipping range is chosen by minimizing KL divergence.

    The histogram itself is collected by ``torch.quantization.HistogramObserver``;
    this subclass overrides ``_non_linear_param_search`` with the entropy
    calibration search from the TensorRT 8-bit inference talk.
    """

    def _compute_threshold(self, distribution: np.ndarray, m_bin_number=2048) -> int:
        """Find the clipping threshold (in bins) that minimizes the KL divergence.

        We filter out outliers in the input distribution by searching for the
        threshold that minimizes the KL divergence. For every candidate threshold
        ``i`` in ``[target_bin_numbers, m_bin_number]``:

        * P is the first ``i`` bins, with all mass at or beyond bin ``i`` folded
          into bin ``i - 1``;
        * Q is the first ``i`` bins merged into ``target_bin_numbers`` levels and
          expanded back to ``i`` bins, spreading each level's mass evenly over its
          non-empty source bins;
        * the score is ``sum(P * log(P / Q))`` over the bins where ``P > 0``.

        Ref: https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

        Args:
            distribution: 1-D histogram of the calibration data. The caller in
                this class passes the histogram normalized to sum to 1, and a
                reversed slice of it that is not re-normalized.
            m_bin_number: largest candidate threshold to try. It is clamped to
                ``distribution.size``.

        Returns:
            The best threshold, counted in bins from the start of ``distribution``.
            If no candidate is evaluated (fewer than ``target_bin_numbers`` bins),
            or none improves on +inf, ``target_bin_numbers`` is returned.
        """
        target_bin_numbers = 128
        # Never look past the end of the histogram; larger values used to raise IndexError.
        m_bin_number = min(m_bin_number, distribution.size)

        threshold = target_bin_numbers
        min_kl_divergence = float('inf')
        # Mass at or beyond the current candidate threshold (folded into P's last bin).
        after_threshold_sum = np.sum(distribution[target_bin_numbers:])

        # Prefix sums (with a leading 0) of the bin masses and of the non-empty-bin count,
        # so the total over any bin range [a, b) is cumsum[b] - cumsum[a].
        cumsum_dist = np.zeros(distribution.size + 1, dtype=distribution.dtype)
        np.cumsum(distribution, out=cumsum_dist[1:])
        cumsum_nozeros = np.zeros(distribution.size + 1, dtype=distribution.dtype)
        np.cumsum(distribution != 0, out=cumsum_nozeros[1:])
        # The trailing False makes index -1 (when left_upper_ == 0) and index
        # distribution.size both read as "empty bin".
        is_nonzero_distribution = np.append(distribution != 0, False)

        j_ = np.arange(target_bin_numbers)
        for i in range(target_bin_numbers, m_bin_number + 1):
            # Reference distribution P: the first i bins, outliers folded into the last one.
            candidate_dis = distribution[:i].copy()
            candidate_dis[i - 1] += after_threshold_sum
            if i < m_bin_number:
                after_threshold_sum -= distribution[i]

            bin_interval = i / target_bin_numbers

            # Merge the i bins into target_bin_numbers bins. Target bin j covers
            # [start_, end_) in source-bin units; edge bins contribute fractionally.
            start_ = j_ * bin_interval
            end_ = start_ + bin_interval
            left_upper_ = np.ceil(start_).astype('int32')
            right_lower_ = np.floor(end_).astype('int32')
            left_flag = left_upper_ > start_
            right_flag = right_lower_ < end_
            left_scale_ = left_upper_ - start_
            right_scale_ = end_ - right_lower_

            quantize_dis = np.zeros(target_bin_numbers)
            quantize_dis[left_flag] += left_scale_[left_flag] * distribution[left_upper_[left_flag] - 1]
            quantize_dis[right_flag] += right_scale_[right_flag] * distribution[right_lower_[right_flag]]
            quantize_dis += cumsum_dist[right_lower_] - cumsum_dist[left_upper_]

            # (Fractional) number of non-empty source bins inside each target bin.
            count_ = np.zeros(target_bin_numbers)
            count_[left_flag] += left_scale_[left_flag] * is_nonzero_distribution[left_upper_[left_flag] - 1]
            count_[right_flag] += right_scale_[right_flag] * is_nonzero_distribution[right_lower_[right_flag]]
            count_ += cumsum_nozeros[right_lower_] - cumsum_nozeros[left_upper_]

            # Expand the target bins back to i bins (distribution Q).
            to_expand_value_ = np.zeros(target_bin_numbers)
            count_flag = count_ != 0
            to_expand_value_[count_flag] = quantize_dis[count_flag] / count_[count_flag]

            expanded_dis = np.zeros(i)
            # np.logical_and is binary: a third positional argument is its `out`
            # buffer, not a third operand. Combine the three masks with `&`.
            left_expand_flag = count_flag & left_flag & is_nonzero_distribution[left_upper_ - 1]
            expanded_dis[left_upper_[left_expand_flag] - 1] += (
                to_expand_value_[left_expand_flag] * left_scale_[left_expand_flag]
            )
            right_expand_flag = count_flag & right_flag & is_nonzero_distribution[right_lower_]
            expanded_dis[right_lower_[right_expand_flag]] += (
                to_expand_value_[right_expand_flag] * right_scale_[right_expand_flag]
            )
            # Each target bin holds either k or k - 1 whole source bins; the k-th
            # one is only filled where last_flag says it exists.
            k = np.floor(bin_interval).astype('int32')
            last_flag = k - (right_lower_ - left_upper_) == 0
            for m in range(right_lower_[0] - 1):
                expanded_dis[left_upper_ + m] += to_expand_value_ * is_nonzero_distribution[left_upper_ + m]
            expanded_dis[left_upper_ + right_lower_[0] - 1] += (
                to_expand_value_ * last_flag * is_nonzero_distribution[left_upper_ + right_lower_[0] - 1]
            )

            # KL(P || Q) over the support of P. Bins with P == 0 contribute 0 by
            # definition. They are masked out explicitly because recent PyTorch
            # evaluates `target * input` there, and 0 * log(0) would give NaN.
            p = torch.from_numpy(candidate_dis)
            q = torch.from_numpy(expanded_dis)
            support = p > 0
            cur_kl = F.kl_div(q[support].log(), p[support], reduction='sum')
            # NOTE(review): why a divergence of exactly 1.0 is skipped is not
            # evident from this file; the check is kept unchanged.
            if cur_kl < min_kl_divergence and cur_kl != 1.0:
                min_kl_divergence = cur_kl
                threshold = i
        return threshold

    def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Search for the clipping range (new_min, new_max) by minimizing KL divergence.

        First the right threshold is searched over the normalized histogram. Then
        the left threshold is searched over the bins below it, walking from the
        right threshold towards the first bin.

        Returns:
            ``(new_min, new_max)`` derived from ``self.min_val`` and the histogram
            bin width. If the adjusted histogram holds no mass,
            ``(self.min_val, self.max_val)`` is returned unchanged.
        """
        bins = self.histogram.clone()
        # NOTE(review): the first bin's count is replaced by the second's, presumably
        # to suppress a spike at the range minimum -- confirm intent.
        bins[0] = bins[1]
        bins_np = bins.numpy()
        bins_np[bins_np < 0] = 0
        total_num = np.sum(bins_np)
        if total_num == 0:
            # Nothing to calibrate on; normalizing would divide by zero.
            return self.min_val, self.max_val
        bins_np = bins_np / total_num
        bin_width = (self.max_val - self.min_val) / self.bins

        # Search over the actual histogram length instead of a fixed 2048 bins.
        right_threshold = self._compute_threshold(bins_np, bins_np.size)

        # Search the left threshold from right_threshold towards bin 0.
        left_dis = bins_np[:right_threshold][::-1]
        left_threshold_ = self._compute_threshold(left_dis, m_bin_number=left_dis.size)
        left_threshold = right_threshold - left_threshold_

        new_min = self.min_val + bin_width * left_threshold
        new_max = self.min_val + bin_width * right_threshold
        return new_min, new_max