in timm/layers/create_attn.py [0:0]
def get_attn(attn_type):
    """Resolve an attention specifier to an attention module class / factory.

    Args:
        attn_type: One of
            * a ``torch.nn.Module`` instance -- returned unchanged;
            * a falsy value (``None``, ``False``, ``''``, ...) -- yields ``None``;
            * ``True`` -- yields ``SEModule``;
            * a string key, matched case-insensitively: ``'se'``, ``'ese'``,
              ``'eca'``, ``'ecam'``, ``'ceca'``, ``'ge'``, ``'gc'``, ``'gca'``,
              ``'cbam'``, ``'lcbam'``, ``'sk'``, ``'splat'``, ``'lambda'``,
              ``'bottleneck'``, ``'halo'``, ``'nl'``, ``'bat'``;
            * any other truthy object -- returned unchanged (the caller is
              presumably supplying its own class or factory).

    Returns:
        The resolved class or ``functools.partial`` wrapping it, the object
        that was passed in, or ``None`` when no attention was requested.
        Apart from an ``nn.Module`` passed straight through, nothing is
        instantiated here.

    Raises:
        AssertionError: If ``attn_type`` is a non-empty string that does not
            match any known key.
    """
    # Already-built modules are handed back untouched.
    if isinstance(attn_type, torch.nn.Module):
        return attn_type

    # Falsy specifier means "no attention".
    if not attn_type:
        return None

    if not isinstance(attn_type, str):
        # ``True`` selects the default (SE); any other truthy non-string is
        # passed through as-is.
        return SEModule if isinstance(attn_type, bool) else attn_type

    attn_type = attn_type.lower()

    # Lightweight attention modules (channel and/or coarse spatial).
    # Typically added to existing network architecture blocks in addition to existing convolutions.
    if attn_type == 'se':
        module_cls = SEModule
    elif attn_type == 'ese':
        module_cls = EffectiveSEModule
    elif attn_type == 'eca':
        module_cls = EcaModule
    elif attn_type == 'ecam':
        module_cls = partial(EcaModule, use_mlp=True)
    elif attn_type == 'ceca':
        module_cls = CecaModule
    elif attn_type == 'ge':
        module_cls = GatherExcite
    elif attn_type == 'gc':
        module_cls = GlobalContext
    elif attn_type == 'gca':
        module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
    elif attn_type == 'cbam':
        module_cls = CbamModule
    elif attn_type == 'lcbam':
        module_cls = LightCbamModule

    # Attention / self-attention / attention-like modules w/ significant compute and/or params.
    # Typically replace some of the existing workhorse convs in a network architecture.
    # All of these accept a stride argument and can spatially downsample the input.
    elif attn_type == 'sk':
        module_cls = SelectiveKernel
    elif attn_type == 'splat':
        module_cls = SplitAttn
    elif attn_type == 'lambda':
        module_cls = LambdaLayer
    elif attn_type == 'bottleneck':
        module_cls = BottleneckAttn
    elif attn_type == 'halo':
        module_cls = HaloAttn
    elif attn_type == 'nl':
        module_cls = NonLocalAttn
    elif attn_type == 'bat':
        module_cls = BatNonLocalAttn

    else:
        # Raised explicitly rather than via ``assert False`` so the check is
        # not stripped when Python runs with -O; the exception type and
        # message are unchanged for existing callers.
        raise AssertionError("Invalid attn module (%s)" % attn_type)

    return module_cls