modules/SwissArmyTransformer/sat/ops/__init__.py (26 lines of code) (raw):
# -*- encoding: utf-8 -*-
'''
@File : __init__.py
@Time : 2022/06/03 23:01:46
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# Standard-library imports.
import os
import sys
import math
import random
# dynamic import according to name with importlib
from importlib import import_module
# Registry of lazily loaded ops: exported op name -> dotted path of the module
# that defines an attribute with that same name.
avaliable_ops = {
    'LayerNorm': 'sat.ops.layernorm',
    'f_similar': 'sat.ops.local_attention_function',
    'f_weighting': 'sat.ops.local_attention_function',
    'FusedScaleMaskSoftmax': 'sat.ops.scaled_mask_softmax',
    'FusedEmaAdam': 'sat.ops.fused_ema_adam',
    'memory_efficient_attention': 'sat.ops.memory_efficient_attention',
}


def _shell_init(self, *args, **kwargs):
    """Accept and ignore any constructor arguments."""
    return None


def _shell_call(self, *args, **kwargs):
    """Import the real op on demand and forward the call to it.

    The module is resolved from ``self.path`` and the op from ``self.name``
    at call time, so nothing is imported until the placeholder is used.
    """
    module = import_module(self.path)
    real_op = getattr(module, self.name)
    return real_op(*args, **kwargs)


# Bind one placeholder object per registered op under the op's own name in
# this module's namespace. Each placeholder is the single instance of a
# dedicated ``<name>Shell`` class carrying the op name and module path as
# class attributes.
for name, path in avaliable_ops.items():
    _shell_attrs = {
        '__init__': _shell_init,
        '__call__': _shell_call,
        'name': name,
        'path': path,
    }
    _shell_cls = type(name + 'Shell', (object,), _shell_attrs)
    # At module scope globals() is the module namespace itself.
    globals()[name] = _shell_cls()