in levit.py [0:0]
def __init__(self, img_size=224,
             patch_size=16,
             in_chans=3,
             num_classes=1000,
             embed_dim=None,
             key_dim=None,
             depth=None,
             num_heads=None,
             attn_ratio=None,
             mlp_ratio=None,
             hybrid_backbone=None,
             down_ops=None,
             attention_activation=torch.nn.Hardswish,
             mlp_activation=torch.nn.Hardswish,
             distillation=True,
             drop_path=0):
    """Build a LeViT model: patch embedding, staged attention blocks,
    optional attention-based downsampling between stages, and a BN+Linear
    classifier head (plus a distillation head, DeiT-style).

    Args:
        img_size: Input image side length; with patch_size it fixes the
            initial token-grid resolution (img_size // patch_size).
        patch_size: Patch stride of the embedding stage.
        in_chans: Input channel count (kept for interface compatibility;
            presumably consumed by the hybrid_backbone — not used here).
        num_classes: Output classes; 0 replaces the head with Identity.
        embed_dim / key_dim / depth / num_heads / attn_ratio / mlp_ratio:
            Per-stage hyperparameter lists, iterated in lockstep
            (defaults: [192] / [64] / [12] / [3] / [2] / [2]).
        hybrid_backbone: Module used as the patch embedding.
        down_ops: Per-stage downsample specs of the form
            ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride);
            default is no downsampling.
        attention_activation / mlp_activation: Activation classes.
        distillation: If True, add a second (distillation) head.
        drop_path: Stochastic-depth rate passed to each Residual wrapper.
    """
    super().__init__()
    global FLOPS_COUNTER

    # Avoid the mutable-default-argument pitfall: the original defaults
    # ([192], [64], ..., []) were single list objects shared across every
    # call, and down_ops was even mutated below — so building the model
    # twice appended multiple [''] sentinels and corrupted later builds.
    embed_dim = [192] if embed_dim is None else embed_dim
    key_dim = [64] if key_dim is None else key_dim
    depth = [12] if depth is None else depth
    num_heads = [3] if num_heads is None else num_heads
    attn_ratio = [2] if attn_ratio is None else attn_ratio
    mlp_ratio = [2] if mlp_ratio is None else mlp_ratio
    # Copy so the caller's list is never mutated; the trailing ['']
    # sentinel lets zip() below visit the final stage (no subsample).
    down_ops = [] if down_ops is None else list(down_ops)
    down_ops.append([''])

    self.num_classes = num_classes
    self.num_features = embed_dim[-1]
    self.embed_dim = embed_dim
    self.distillation = distillation
    self.patch_embed = hybrid_backbone

    self.blocks = []
    resolution = img_size // patch_size
    for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
            zip(embed_dim, key_dim, depth, num_heads,
                attn_ratio, mlp_ratio, down_ops)):
        # Stage i: dpth attention blocks, each optionally followed by an
        # MLP (Linear_BN -> activation -> Linear_BN), all residual-wrapped.
        for _ in range(dpth):
            self.blocks.append(
                Residual(Attention(
                    ed, kd, nh,
                    attn_ratio=ar,
                    activation=attention_activation,
                    resolution=resolution,
                ), drop_path))
            if mr > 0:
                h = int(ed * mr)
                self.blocks.append(
                    Residual(torch.nn.Sequential(
                        Linear_BN(ed, h, resolution=resolution),
                        mlp_activation(),
                        Linear_BN(h, ed, bn_weight_init=0,
                                  resolution=resolution),
                    ), drop_path))
        if do[0] == 'Subsample':
            # ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride)
            # Strided subsample: ceil-divide the token-grid resolution.
            resolution_ = (resolution - 1) // do[5] + 1
            self.blocks.append(
                AttentionSubsample(
                    *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
                    attn_ratio=do[3],
                    activation=attention_activation,
                    stride=do[5],
                    resolution=resolution,
                    resolution_=resolution_))
            resolution = resolution_
            if do[4] > 0:  # mlp_ratio of the downsample block
                h = int(embed_dim[i + 1] * do[4])
                self.blocks.append(
                    Residual(torch.nn.Sequential(
                        Linear_BN(embed_dim[i + 1], h,
                                  resolution=resolution),
                        mlp_activation(),
                        Linear_BN(
                            h, embed_dim[i + 1], bn_weight_init=0,
                            resolution=resolution),
                    ), drop_path))
    self.blocks = torch.nn.Sequential(*self.blocks)

    # Classifier head; a second head is added when training with
    # distillation (its output is averaged with the main head at eval).
    self.head = BN_Linear(
        embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
    if distillation:
        self.head_dist = BN_Linear(
            embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

    # Snapshot the module-level FLOPs tally accumulated while the layers
    # above were constructed, then reset it for the next model build.
    self.FLOPS = FLOPS_COUNTER
    FLOPS_COUNTER = 0