def disable_requantization_for_cat_pass()

in tinynn/graph/quantization/quantizer.py [0:0]


    def disable_requantization_for_cat_pass(self, graph):
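        """Make quantized `cat` nodes share one observer with their neighbours.

        Starting from every quantized `cat` node, walk the trace graph in both
        directions, collect the `activation_post_process` observers of the
        surrounding modules, and finally rebind them to a single shared instance
        so that no requantization is needed around the concatenation.
        """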
        def _find_quantized_cat_nodes(node: TraceNode, custom_node):
            # Find quantized cat nodes
            return node.type() == 'cat' and node.quantized

        # For each cat node, the `activation_post_process` modules around it need to be unified
        quantized_cat_nodes = graph.filter_forward_nodes(_find_quantized_cat_nodes)

        q = queue.Queue()
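        # Each queue entry is (node, traversal direction: 'up' | 'down' | 'both',
        # number of fake-quantize/observer modules crossed so far)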
        visited_center = set()
        for n in quantized_cat_nodes:
            q.put((n, 'both', 0))
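            # Observers collected for this cat node: owner modules, dotted names
            # (for logging) and the attribute name that holds the observer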
            parents = []
            names = []
            props = []
            visited_other = dict()
            while not q.empty():
                n, mode, fq_count = q.get()
                if (
                    n.kind() in ('shape', 'size')
                    or n.unique_name in visited_center
                    or visited_other.get(n.unique_name, 2) <= fq_count
                ):
                    continue

                if n.type() == 'cat':
                    visited_center.add(n.unique_name)
                else:
                    visited_other[n.unique_name] = fq_count

                # Number of fake-quantize/observer modules seen between this node and the cat
                new_fq_count = fq_count

                if isinstance(n.module, nn.Module):
                    is_prev_float_functional = False
                    orig_name = graph.module_original_name_dict.get(id(n.module))
                    new_mod, parent = graph.get_submodule_with_parent_from_name(orig_name, self.inplace)
                    prop = orig_name.split('.')[-1]
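                    # For quantizable LSTM/GRU wrappers, the observer to share is the one
                    # attached to the op producing the last layer's output
                    # (`ogate_cy` for LSTM, `add4` for GRU)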
                    if QuantizableLSTM is not None and isinstance(new_mod, QuantizableLSTM):
                        if new_fq_count == 0:
                            if new_mod.bidirectional is False:
                                parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy)
                                names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process')
                                props.append('activation_post_process')
                            else:
                                parents.append(new_mod.layers[-1].layer_fw.cell.ogate_cy)
                                names.append(f'{orig_name}.layer_fw.cell.ogate_cy.activation_post_process')
                                props.append('activation_post_process')
                                parents.append(new_mod.layers[-1].layer_bw.cell.ogate_cy)
                                names.append(f'{orig_name}.layer_bw.cell.ogate_cy.activation_post_process')
                                props.append('activation_post_process')
                        new_fq_count += 1
                    elif QuantizableGRU is not None and isinstance(new_mod, QuantizableGRU):
                        if new_fq_count == 0:
                            if new_mod.bidirectional is False:
                                parents.append(new_mod.layers[-1].layer_fw.cell.add4)
                                names.append(f'{orig_name}.layer_fw.cell.add4.activation_post_process')
                                props.append('activation_post_process')
                            else:
                                parents.append(new_mod.layers[-1].layer_fw.cell.add4)
                                names.append(f'{orig_name}.layer_fw.cell.add4.activation_post_process')
                                props.append('activation_post_process')
                                parents.append(new_mod.layers[-1].layer_bw.cell.add4)
                                names.append(f'{orig_name}.layer_bw.cell.add4.activation_post_process')
                                props.append('activation_post_process')
                        new_fq_count += 1
                    elif isinstance(new_mod, (torch_q.FakeQuantize, torch_q.ObserverBase)):
                        if new_fq_count == 0:
                            parents.append(parent)
                            names.append(orig_name)
                            props.append(prop)
                        new_fq_count += 1
                    elif hasattr(new_mod, 'activation_post_process'):
                        if new_fq_count == 0:
                            parents.append(new_mod)
                            names.append(f'{orig_name}.activation_post_process')
                            props.append('activation_post_process')
                        new_fq_count += 1
                    elif (
                        isinstance(new_mod, nn.Sequential)
                        and type(new_mod).__module__.startswith(nni.__name__)
                        and len(new_mod) > 0
                        and hasattr(new_mod[-1], 'activation_post_process')
                    ):
                        if new_fq_count == 0:
                            parents.append(new_mod[-1])
                            names.append(f'{orig_name}[-1].activation_post_process')
                            props.append('activation_post_process')
                        new_fq_count += 1
                    if isinstance(new_mod, (torch_q.DeQuantStub, torch_q.QuantStub)):
                        new_fq_count = 2
                else:
                    is_prev_float_functional = (
                        len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional
                    )
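                    # Reaching another cat node restarts the walk in both directions
                    # with fresh fake-quant counts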
                    if n.type() == 'cat':
                        mode = 'both'
                        fq_count = 0
                        new_fq_count = 0
                    if is_prev_float_functional:
                        m = n.prev_nodes[0].module
                        orig_name = graph.module_original_name_dict.get(id(m))
                        if new_fq_count == 0:
                            parents.append(m)
                            names.append(f'{orig_name}.activation_post_process')
                            props.append('activation_post_process')
                        new_fq_count += 1

                # Only the count in the direction being walked picks up the newly seen observers
                if mode in ('both', 'down'):
                    fq_up = fq_count
                    fq_down = new_fq_count
                elif mode == 'up':
                    fq_up = new_fq_count
                    fq_down = fq_count

                # A node reached while walking up that fans out to multiple consumers
                # must also be walked downwards
                if mode == 'up' and len(n.next_nodes) > 1:
                    mode = 'both'
                    fq_down += 1

                if mode in ('both', 'up'):
                    for i, node in enumerate(n.prev_nodes):
                        if is_prev_float_functional and i == 0:
                            continue
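                        # Stop expanding once two fake-quantize boundaries have been crossed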
                        if fq_up < 2:
                            q.put((node, 'up', fq_up))
                if mode in ('both', 'down'):
                    for node in n.next_nodes:
                        if fq_down < 2:
                            q.put((node, 'down', fq_down))

            # Rebind every collected observer attribute to a single shared instance so
            # that all tensors flowing into and out of the cat use identical qparams
            if len(names) > 1:
                log.debug(f'Unifying the following nodes into one: {", ".join(names)}')
                unified = getattr(parents[0], props[0])
                for parent, prop in zip(parents[1:], props[1:]):
                    setattr(parent, prop, unified)
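
The effect of the final `setattr` loop can be illustrated outside of tinynn. The sketch below is a minimal, hand-written eager-mode QAT example, not part of this pass: the `TwoBranch` module and the `branch_a` / `branch_b` names are made up for illustration, and `torch_q` here aliases `torch.ao.quantization` rather than the alias used in the surrounding file. It shows that rebinding one module's `activation_post_process` to another module's observer object makes both outputs share scale and zero-point, which is exactly why the pass above removes the need to requantize around `cat`.

import torch
import torch.nn as nn
import torch.ao.quantization as torch_q


class TwoBranch(nn.Module):
    # Hypothetical model: two conv branches whose outputs are concatenated
    def __init__(self):
        super().__init__()
        self.quant = torch_q.QuantStub()
        self.branch_a = nn.Conv2d(3, 8, 1)
        self.branch_b = nn.Conv2d(3, 8, 1)
        self.dequant = torch_q.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        y = torch.cat([self.branch_a(x), self.branch_b(x)], dim=1)
        return self.dequant(y)


model = TwoBranch()
model.qconfig = torch_q.get_default_qat_qconfig('qnnpack')
torch_q.prepare_qat(model, inplace=True)

# Same idea as the pass above, done by hand: point both branches at one observer
# object, so calibration/QAT drives their outputs to identical qparams and the
# downstream cat needs no requantization.
shared = model.branch_a.activation_post_process
model.branch_b.activation_post_process = shared
assert model.branch_a.activation_post_process is model.branch_b.activation_post_process

Because the eager-mode observer hook looks up `activation_post_process` dynamically on each forward call, rebinding the attribute is sufficient; this is the same mechanism the pass relies on with its `getattr`/`setattr` pair.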