def _init_repr_shape()

in mobile_cv/lut/lib/pt/flops_utils.py [0:0]


    def _init_repr_shape(self):
        """Build patchers that append hook info to ``nn.Module.extra_repr``.

        For each module class found in ``self.model`` a
        ``mock.patch.object`` patcher is created for that class's
        ``extra_repr``. While a patcher is active, printing the model also
        shows, for every module, the entries of ``self._hook.data(module)``
        whose keys are listed in ``REPR_ITEMS``.

        The patchers are stored in ``self._patchers`` but are not started
        here. Must be called while ``self._patchers`` is still empty.
        """
        assert len(self._patchers) == 0

        # Keys from the hook's per-module data that are shown in the repr.
        REPR_ITEMS = ("input_shapes", "output_shapes", "nparams", "nflops")

        def decor_extra_repr(orig_extra_repr):
            """Return a replacement for ``orig_extra_repr`` that appends the
            selected hook info for the module to the original repr string."""

            def new_extra_repr(module):
                info = self._hook.data(module)
                if info is None:
                    extra = []
                else:
                    # Keep the hook's own key order; only whitelisted keys.
                    extra = [f"{k}={v}" for k, v in info.items() if k in REPR_ITEMS]
                info_str = ", ".join(extra)

                ret = orig_extra_repr(module)
                # Put the hook info on its own line after any existing repr.
                if len(ret) > 0 and len(info_str) > 0:
                    ret = ret + ",\n"
                return ret + info_str

            return new_extra_repr

        def _get_unique_parents(types):
            """Filter (class, method) pairs so derived classes are patched safely.

            Mocking a class method (``patch.object``) may fail in some cases
            when handling derived classes. To avoid this issue:
              1. For all subclasses that did not override the class method,
                 only the base class method is mocked.
              2. A subclass method is mocked if the subclass overrides it.
            This is done by grouping the classes by the method that will be
            mocked and, within each group, removing classes that are
            subclasses of others (via ``get_unique_parent_types``).

            Args:
                types: sequence of ``(class, method)`` tuples.

            Returns:
                A list of ``(class, method)`` tuples: the filtered input,
                grouped by method in first-seen order.
            """
            assert all(isinstance(x, tuple) for x in types)

            classes_by_method = {}
            for cur_type, cur_method in types:
                classes_by_method.setdefault(cur_method, []).append(cur_type)

            ret = []
            for method, cur_types in classes_by_method.items():
                ret.extend((ct, method) for ct in get_unique_parent_types(cur_types))
            return ret

        # nn.Module subclass types that need to be patched, as
        # (module type, module type's extra_repr function) pairs. A dict is
        # used as an insertion-ordered set (a ``set`` iterates in hash order,
        # which varies between runs) so the input to ``_get_unique_parents``
        # follows the order in which ``self.model.apply`` visits modules.
        seen_types = {}

        def _collect(m):
            seen_types[(type(m), type(m).extra_repr)] = None

        self.model.apply(_collect)
        all_types = _get_unique_parents(list(seen_types))

        # autospec=True keeps the original method signature, so the module
        # instance is passed through to the side effect as ``module``.
        self._patchers = [
            mock.patch.object(
                cls, "extra_repr", side_effect=decor_extra_repr(orig), autospec=True
            )
            for cls, orig in all_types
        ]