tinynn/converter/operators/hybrid_quantizer.py (216 lines of code) (raw):

import copy import functools import igraph as ig import numpy as np import torch from tinynn.util.util import class_conditional, get_logger from . import tflite as tfl from .base import ExtendedOperator from .graph import CommonGraph log = get_logger(__name__) WEIGHT_MAPPING = { ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: [1, 2, 3, 4, 5, 6, 7, 8], ExtendedOperator.BIDIRECTIONAL_SEQUENCE_LSTM: [1, 2, 3, 4, 5, 6, 7, 8, 18, 19, 20, 21, 22, 23, 24, 25], } BIAS_MAPPING = { ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: {1: 12, 2: 13, 3: 14, 4: 15}, } STATE_MAPPING = { ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: [18], } CELL_STATE_MAPPING = { ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: [19], } class HybridQuantizer(object): graph: CommonGraph def __init__( self, graph, asymmetric, q_type, per_channel, enable_conv, enable_int16_lstm, gen_single_op_models, config ) -> None: super().__init__() self.graph = graph self.asymmetric = asymmetric self.q_type = q_type self.per_channel = per_channel self.enable_conv = enable_conv self.enable_int16_lstm = enable_int16_lstm self.gen_single_op_models = gen_single_op_models if config is None: config = {} self.config = config def quantize(self): self.quantize_pass() self.int16_lstm_pass() @class_conditional(lambda self: self.enable_int16_lstm) def int16_lstm_pass(self): filtered_nodes = self.graph.graph.vs.select(functools.partial(is_int16_quantizable_lstm_node)) actions = [] replaced_tensors = {} for node in filtered_nodes: if self.config.get(node['outputs'][0], True) is False: continue if node['node_type'] == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: lstm_input = node['op'].inputs[0] if lstm_input.dtype == np.int8: bias_indices = BIAS_MAPPING.get(node['node_type']) for weight_idx, bias_idx in bias_indices.items(): bias_t = node['op'].inputs[bias_idx] weight_t = node['op'].inputs[weight_idx] name = bias_t.name new_name = f'{name}_hybrid_q' bias_a = np.frombuffer(bias_t.buffer.data, dtype='float32').reshape(bias_t.shape) bias = torch.from_numpy(bias_a.copy()) bias_scale = weight_t.quantization.scale * lstm_input.quantization.scale new_bias = torch.round(bias.detach() / bias_scale).to(dtype=torch.int32) new_bias_t = tfl.Tensor(tfl.FakeQuantTensor(new_bias, bias_scale, 0), new_name) replaced_tensors.setdefault(new_bias_t.name, new_bias_t) new_bias_t = replaced_tensors[new_bias_t.name] actions.append((self.graph.replace_operator_input, (node, bias_idx, new_bias_t))) state_indices = STATE_MAPPING.get(node['node_type']) for state_idx in state_indices: node['op'].inputs[state_idx].quantization = copy.deepcopy(node['op'].outputs[0].quantization) node['op'].inputs[state_idx].tensor = node['op'].inputs[state_idx].tensor.astype(np.int8) node['op'].inputs[state_idx].dtype = node['op'].inputs[state_idx].tensor.dtype cell_state_indices = CELL_STATE_MAPPING.get(node['node_type']) for cell_state_idx in cell_state_indices: q_cell_output = self.graph.rev_q_mapping[node['op'].extra_hints['cell_output']].quantization q_cell_max = q_cell_output.scale * (127 - q_cell_output.zero_point) q_cell_min = q_cell_output.scale * (-128 - q_cell_output.zero_point) q_cell_abs_max = np.maximum(np.abs(q_cell_max), np.abs(q_cell_min)) cell_pot = np.power(2, np.maximum(np.ceil(np.log2(q_cell_abs_max)), 0)).item() node['op'].inputs[cell_state_idx].quantization = tfl.QuantizationParameters(cell_pot / 32768, 0) node['op'].inputs[cell_state_idx].tensor = ( node['op'].inputs[cell_state_idx].tensor.astype(np.int16) ) node['op'].inputs[cell_state_idx].dtype = node['op'].inputs[cell_state_idx].tensor.dtype # Add intermediates for int8x8_16 lstm name = node['op'].outputs[0].name input_to_input_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_1') input_to_forget_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_2') input_to_cell_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_3') input_to_output_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_4') effective_hidden_scale_intermediate = tfl.Tensor( tfl.FakeQuantTensor(np.zeros(0, dtype='int8'), node['op'].outputs[0].quantization.scale, 0), f'{name}_intermediate_5', ) actions.append((self.graph.append_operator_input, (node, input_to_input_intermediate, True))) actions.append((self.graph.append_operator_input, (node, input_to_forget_intermediate, True))) actions.append((self.graph.append_operator_input, (node, input_to_cell_intermediate, True))) actions.append((self.graph.append_operator_input, (node, input_to_output_intermediate, True))) actions.append( (self.graph.append_operator_input, (node, effective_hidden_scale_intermediate, True)) ) for func, args in actions: func(*args) def quantize_pass(self): filtered_nodes = self.graph.graph.vs.select(functools.partial(is_quantizable_node, with_conv=self.enable_conv)) actions = [] replaced_tensors = {} for node in filtered_nodes: if self.config.get(node['outputs'][0], True) is False: continue weight_indices = WEIGHT_MAPPING.get(node['node_type'], [1]) skip = False for weight_idx in weight_indices: new_weight = None weight_t = node['op'].inputs[weight_idx] if weight_t.buffer is None or str(weight_t.dtype) != 'float32': skip = True break if skip: continue for weight_idx in weight_indices: weight_t = node['op'].inputs[weight_idx] name = weight_t.name weight_a = np.frombuffer(weight_t.buffer.data, dtype='float32').reshape(weight_t.shape) weight = torch.from_numpy(weight_a.copy()) if ( node['node_type'] in ( ExtendedOperator.FULLY_CONNECTED, ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM, ExtendedOperator.BIDIRECTIONAL_SEQUENCE_LSTM, ) or not self.per_channel ): if node['node_type'] == ExtendedOperator.DEPTHWISE_CONV_2D: log.warning('DEPTHWISE_CONV_2D doesn\'t support hybrid per-tensor quantization') continue if self.asymmetric and hasattr(node['op'], 'asymmetricQuantizeInputs'): node['op'].asymmetricQuantizeInputs = True if self.q_type == np.uint8: new_weight = quantize(name, weight, torch.qint8, torch.per_tensor_symmetric, q_type=np.int8) new_weight.reinterpret_as(self.q_type) else: new_weight = quantize(name, weight, torch.qint8, torch.per_tensor_symmetric, q_type=self.q_type) elif node['node_type'] == ExtendedOperator.CONV_2D: new_weight = quantize(name, weight, torch.qint8, torch.per_channel_symmetric, 0, q_type=self.q_type) elif node['node_type'] == ExtendedOperator.DEPTHWISE_CONV_2D: new_weight = quantize( name, weight, torch.qint8, torch.per_channel_symmetric, -1, q_type=self.q_type ) if self.gen_single_op_models: node['op'].extra_hints['orig_float'] = copy.deepcopy(node['op']) replaced_tensors.setdefault(new_weight.name, new_weight) new_weight = replaced_tensors[new_weight.name] actions.append((self.graph.replace_operator_input, (node, weight_idx, new_weight))) for func, args in actions: func(*args) def is_quantizable_node(vertex: ig.Vertex, with_conv: bool): return vertex['node_type'] in ( ExtendedOperator.FULLY_CONNECTED, ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM, ExtendedOperator.BIDIRECTIONAL_SEQUENCE_LSTM, ) or ( with_conv and vertex['node_type'] in ( ExtendedOperator.CONV_2D, ExtendedOperator.DEPTHWISE_CONV_2D, ) ) def is_int16_quantizable_lstm_node(vertex: ig.Vertex): return vertex['node_type'] in (ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM,) def quantize(name, tensor, dtype, qscheme, axis=None, q_type=np.uint8): assert qscheme in (torch.per_tensor_symmetric, torch.per_channel_symmetric) new_name = f'{name}_hybrid_q' if dtype == torch.quint8: quant_min, quant_max = 0, 255 else: quant_min, quant_max = -127, 127 if axis is not None: if axis < 0: axis += tensor.ndim dim = [i for i in range(tensor.ndim) if i != axis] else: dim = None if hasattr(torch, 'amin') and hasattr(torch, 'amax'): min_val = torch.amin(tensor, dim) max_val = torch.amax(tensor, dim) else: if dim is None: min_val = torch.min(tensor) max_val = torch.max(tensor) else: orig_dim = tensor.size(axis) if axis != 0: perm = [axis] + dim tensor_perm = tensor.permute(perm) else: tensor_perm = tensor tensor_2d = tensor_perm.reshape(orig_dim, -1) min_val, _ = torch.min(tensor_2d, 1) max_val, _ = torch.max(tensor_2d, 1) min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) scale = torch.ones(min_val_neg.size(), dtype=torch.float32) zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64) eps = torch.tensor(torch.finfo(torch.float32).eps) max_val_pos = torch.max(-min_val_neg, max_val_pos) scale = max_val_pos / (float(quant_max - quant_min) / 2) scale = torch.max(scale, eps) if dtype == torch.quint8: zero_point = zero_point.new_full(zero_point.size(), 128) if qscheme == torch.per_channel_symmetric: q_tensor = torch.quantize_per_channel(tensor, scale, zero_point, axis, dtype) else: q_tensor = torch.quantize_per_tensor(tensor, scale, zero_point, dtype) return tfl.Tensor(q_tensor, new_name, q_type=q_type)