def collect_hooks_()

in modules/SwissArmyTransformer/sat/model/base_model.py [0:0]


    def collect_hooks_(self):
        """Resolve the implementation of every hook named in ``HOOKS_DEFAULT``.

        For each hook name, an attribute of that name on the model itself is
        registered first (origin ``'model'``). The mixins in ``self.mixins``
        are then visited in iteration order, and each mixin exposing an
        attribute of that name takes one of three paths:

        * The attribute carries a ``non_conflict`` marker: it is chained on
          top of the current implementation, which it receives through its
          ``old_impl`` keyword argument. When nothing is registered yet, the
          entry from ``HOOKS_DEFAULT`` is used as the previous implementation.
        * A hook is already registered and lacks the ``replacable`` marker:
          this is a conflict and an error is raised.
        * Otherwise the mixin's hook is registered, with a warning if it
          replaces a hook marked ``replacable``.

        The results are stored on ``self.hooks`` and ``self.hook_origins``.

        Returns:
            dict: hook name -> callable. Hook names with no model or mixin
            implementation are absent from the dict.

        Raises:
            ValueError: if a ``non_conflict`` hook does not accept an
                ``old_impl`` parameter, or if a mixin hook collides with an
                already registered hook that is not ``replacable``.
        """
        registered = {}
        origins = {}

        for name in list(HOOKS_DEFAULT):
            # An implementation defined on the model itself goes in first.
            if hasattr(self, name):
                registered[name] = getattr(self, name)
                origins[name] = 'model'

            for mixin_name, mixin in self.mixins.items():
                if not hasattr(mixin, name):
                    continue
                candidate = getattr(mixin, name)

                if hasattr(candidate, 'non_conflict'):
                    # A chaining hook has to be able to receive the
                    # implementation it wraps.
                    if 'old_impl' not in inspect.signature(candidate).parameters:
                        raise ValueError(f'Hook {name} at {mixin_name} must accept old_impl as an argument.')

                    if name in registered:
                        previous = registered[name]
                    elif name == 'attention_fn':
                        # The only default hook that does not take `self`.
                        previous = HOOKS_DEFAULT[name]
                    else:
                        # Bind `self` so the default is called like a
                        # registered hook.
                        previous = partial(HOOKS_DEFAULT[name], self)

                    previous_origin = origins.get(name, 'default')
                    registered[name] = partial(candidate, old_impl=previous)
                    origins[name] = mixin_name + ' -> ' + previous_origin
                    continue

                if name in registered:
                    # Only a hook marked `replacable` may be overridden.
                    if not hasattr(registered[name], 'replacable'):
                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {origins[name]}.')
                    warnings.warn(f'Hook {name} at {mixin_name} replaces {origins[name]}.')

                registered[name] = candidate
                origins[name] = mixin_name

        self.hooks = registered
        self.hook_origins = origins
        return registered