in modules/SwissArmyTransformer/sat/model/base_model.py [0:0]
def collect_hooks_(self):
    """Resolve every hook in ``HOOKS_DEFAULT`` from the model and its mixins.

    For each hook name (in ``HOOKS_DEFAULT`` key order), an attribute of that
    name on the model itself is registered first, with origin ``'model'``.
    Then every mixin in ``self.mixins`` is consulted in iteration order:

    * If the mixin's hook carries a ``non_conflict`` attribute, it is chained
      onto the current implementation. It is wrapped with
      ``partial(hook, old_impl=previous)``, where ``previous`` is the hook
      registered so far. If nothing is registered yet, ``previous`` is
      ``HOOKS_DEFAULT[name]``. That default is bound to ``self`` via
      ``partial`` for every hook except ``attention_fn``. The origin becomes
      ``'<mixin> -> <previous origin>'``.
    * Otherwise, if an earlier registration carries a ``replacable``
      attribute, the mixin's hook replaces it and a warning is emitted.
    * Otherwise, if the name is already registered, it is a conflict.

    Side effects:
        Assigns the results to ``self.hooks`` and ``self.hook_origins``.

    Returns:
        The hook table (the same dict object stored on ``self.hooks``),
        mapping hook name -> callable.

    Raises:
        ValueError: If a ``non_conflict`` hook has no ``old_impl`` parameter,
            or if two hooks without the replace/chain markers claim one name.
    """
    absent = object()  # sentinel: "attribute missing" vs. an attribute set to None
    table = {}
    origins = {}

    for hook_name in tuple(HOOKS_DEFAULT.keys()):
        own_impl = getattr(self, hook_name, absent)
        if own_impl is not absent:
            table[hook_name] = own_impl
            origins[hook_name] = 'model'

        for mixin_name, mixin in self.mixins.items():
            candidate = getattr(mixin, hook_name, absent)
            if candidate is absent:
                continue

            if hasattr(candidate, 'non_conflict'):
                # A chaining hook receives the previous implementation as
                # `old_impl`, so its signature must expose that parameter.
                if 'old_impl' not in inspect.signature(candidate).parameters:
                    raise ValueError(f'Hook {hook_name} at {mixin_name} must accept old_impl as an argument.')
                if hook_name in table:
                    previous = table[hook_name]
                elif hook_name == 'attention_fn':
                    # attention_fn is the one default hook that is not bound to self
                    previous = HOOKS_DEFAULT[hook_name]
                else:
                    previous = partial(HOOKS_DEFAULT[hook_name], self)
                chained_from = origins.get(hook_name, 'default')
                table[hook_name] = partial(candidate, old_impl=previous)
                origins[hook_name] = mixin_name + ' -> ' + chained_from
                continue

            already = hook_name in table
            if already and not hasattr(table[hook_name], 'replacable'):
                raise ValueError(f'Hook {hook_name} conflicts at {mixin_name} and {origins[hook_name]}.')
            if already:
                # The earlier registration opted in to being overridden.
                warnings.warn(f'Hook {hook_name} at {mixin_name} replaces {origins[hook_name]}.')
            table[hook_name] = candidate
            origins[hook_name] = mixin_name

    self.hooks = table
    self.hook_origins = origins
    return table