in tinynn/graph/quantization/quantizer.py [0:0]
def disable_requantization_for_cat_pass(self, graph):
    def _find_quantized_cat_nodes(node: TraceNode, custom_node):
        # Find quantized cat nodes
        return node.type() == 'cat' and node.quantized

    # For each cat node, the `activation_post_process` modules around it need to be unified
    quantized_cat_nodes = graph.filter_forward_nodes(_find_quantized_cat_nodes)
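
    # Starting from each quantized `cat`, walk the graph breadth-first in both
    # directions, collecting the surrounding observers so that they can later
    # be replaced with one shared instance (no requantization at the `cat`).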
    q = queue.Queue()
    visited_center = set()
    for n in quantized_cat_nodes:
        q.put((n, 'both', 0))

        parents = []
        names = []
        props = []
        visited_other = dict()
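
        # Queue items are `(node, direction, fq_count)`: direction is 'up'
        # (towards the inputs), 'down' (towards the outputs) or 'both', and
        # fq_count is the number of fake-quantize boundaries crossed so far.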
        while not q.empty():
            n, mode, fq_count = q.get()
            if (
                n.kind() in ('shape', 'size')
                or n.unique_name in visited_center
                or visited_other.get(n.unique_name, 2) <= fq_count
            ):
                continue
            if n.type() == 'cat':
                visited_center.add(n.unique_name)
            else:
                visited_other[n.unique_name] = fq_count
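
            # Once two fake-quantize boundaries have been crossed, the nodes
            # beyond them no longer share qparams with the `cat`, so a count
            # reaching 2 stops the traversal (see the `< 2` checks below).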
            new_fq_count = fq_count
            if isinstance(n.module, nn.Module):
                is_prev_float_functional = False
                orig_name = graph.module_original_name_dict.get(id(n.module))
                new_mod, parent = graph.get_submodule_with_parent_from_name(orig_name, self.inplace)
                prop = orig_name.split('.')[-1]
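
                # The branches below locate the observer to unify for each kind
                # of module: quantizable LSTM/GRU cells, standalone observers,
                # observed modules and fused `nni` sequences.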
                if QuantizableLSTM is not None and isinstance(new_mod, QuantizableLSTM):
                    if new_fq_count == 0:
                        if new_mod.bidirectional is False:
                            parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy)
                            names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process')
                            props.append('activation_post_process')
                        else:
                            parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy)
                            names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process')
                            props.append('activation_post_process')
                            parents.append(new_mod.layers[-1].layer_bw.cell.ogate_cy)
                            names.append(f'{orig_name}.layer_bw.cell.ogate_cy.activation_post_process')
                            props.append('activation_post_process')
                    new_fq_count += 1
                elif QuantizableGRU is not None and isinstance(new_mod, QuantizableGRU):
                    if new_fq_count == 0:
                        if new_mod.bidirectional is False:
                            parents.append(new_mod.layers[-1].layer_fw.cell.add4)
                            names.append(f'{orig_name}.layer_fw.cell.add4.activation_post_process')
                            props.append('activation_post_process')
                        else:
                            parents.append(new_mod.layers[-1].layer_fw.cell.add4)
                            names.append(f'{orig_name}.layer_fw.cell.add4.activation_post_process')
                            props.append('activation_post_process')
                            parents.append(new_mod.layers[-1].layer_bw.cell.add4)
                            names.append(f'{orig_name}.layer_bw.cell.add4.activation_post_process')
                            props.append('activation_post_process')
                    new_fq_count += 1
                elif isinstance(new_mod, (torch_q.FakeQuantize, torch_q.ObserverBase)):
                    if new_fq_count == 0:
                        parents.append(parent)
                        names.append(orig_name)
                        props.append(prop)
                    new_fq_count += 1
                elif hasattr(new_mod, 'activation_post_process'):
                    if new_fq_count == 0:
                        parents.append(new_mod)
                        names.append(f'{orig_name}.activation_post_process')
                        props.append('activation_post_process')
                    new_fq_count += 1
                elif (
                    isinstance(new_mod, nn.Sequential)
                    and type(new_mod).__module__.startswith(nni.__name__)
                    and len(new_mod) > 0
                    and hasattr(new_mod[-1], 'activation_post_process')
                ):
                    if new_fq_count == 0:
                        parents.append(new_mod[-1])
                        names.append(f'{orig_name}[-1].activation_post_process')
                        props.append('activation_post_process')
                    new_fq_count += 1
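
                # Quant/dequant stubs mark the edge of the quantized region;
                # setting the count to 2 keeps the walk from going past them.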
                if isinstance(new_mod, (torch_q.DeQuantStub, torch_q.QuantStub)):
                    new_fq_count = 2
            else:
                is_prev_float_functional = (
                    len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional
                )
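
            # Reaching another quantized `cat` re-centers the search: its
            # observers must be unified with the current group as well, so the
            # direction is reset to 'both' and the boundary counters to zero.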
            if n.type() == 'cat':
                mode = 'both'
                fq_count = 0
                new_fq_count = 0
            if is_prev_float_functional:
                m = n.prev_nodes[0].module
                orig_name = graph.module_original_name_dict.get(id(m))
                if new_fq_count == 0:
                    parents.append(m)
                    names.append(f'{orig_name}.activation_post_process')
                    props.append('activation_post_process')
                new_fq_count += 1
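
            # An observer found at this node sits between it and the nodes
            # downstream, so the updated count is passed downwards ('both' and
            # 'down'), or upwards when travelling 'up'; the opposite direction
            # keeps the count the node was reached with.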
            if mode in ('both', 'down'):
                fq_up = fq_count
                fq_down = new_fq_count
            elif mode == 'up':
                fq_up = new_fq_count
                fq_down = fq_count
            if mode == 'up' and len(n.next_nodes) > 1:
                mode = 'both'
                fq_down += 1
            if mode in ('both', 'up'):
                for i, node in enumerate(n.prev_nodes):
                    if is_prev_float_functional and i == 0:
                        continue
                    if fq_up < 2:
                        q.put((node, 'up', fq_up))
            if mode in ('both', 'down'):
                for node in n.next_nodes:
                    if fq_down < 2:
                        q.put((node, 'down', fq_down))

        if len(names) > 1:
            log.debug(f'Unifying the following nodes into one: {", ".join(names)}')
            unified = getattr(parents[0], props[0])
            for parent, prop in zip(parents[1:], props[1:]):
                setattr(parent, prop, unified)
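
# Illustrative sketch (an assumption, not part of the original file): after the
# pass runs, every entry collected for a given `cat` refers to the very same
# observer object, so the inputs and outputs of the concatenation report
# identical quantization parameters:
#
#     shared = getattr(parents[0], props[0])
#     assert all(getattr(p, prop) is shared for p, prop in zip(parents, props))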