in pycls/models/efficientnet.py [0:0]
def _construct_class(self, in_w, exp_r, kernel, stride, se_r, out_w, act_fun, exp_w):
    """Create the sub-modules of an inverted-bottleneck (MBConv-style) block.

    Layers are attached to ``self`` in this order: optional 1x1 expansion,
    depthwise convolution, optional SE, 1x1 projection. The creation order
    is kept stable because it determines module registration order.

    Args:
        in_w: Number of input channels of the block.
        exp_r: Expansion ratio. When ``int(exp_r) == 1`` the expansion stage
            is skipped and ``self.expand`` stays ``None``.
        kernel: Integer kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        se_r: SE ratio; the SE width is ``int(in_w * se_r)``.
        out_w: Number of output channels of the block.
        act_fun: Zero-argument callable returning an activation module.
        exp_w: Channel width after expansion. Ignored (replaced by ``in_w``)
            when the expansion stage is skipped.
    """

    def talk_conv_1x1(w_in, w_out):
        # Pointwise TalkConv2d configured from cfg.RGRAPH; used by both the
        # expansion and the projection stage so the two stay in sync.
        return TalkConv2d(
            w_in, w_out, cfg.RGRAPH.GROUP_NUM, kernel_size=1,
            stride=1, padding=0, bias=False,
            message_type=cfg.RGRAPH.MESSAGE_TYPE,
            directed=cfg.RGRAPH.DIRECTED,
            agg=cfg.RGRAPH.AGG_FUNC,
            sparsity=cfg.RGRAPH.SPARSITY,
            p=cfg.RGRAPH.P,
            talk_mode=cfg.RGRAPH.TALK_MODE,
            seed=self.seed,
        )

    # Expansion: 1x1, BN, activation
    self.expand = None
    # NOTE(review): int() truncates, so any exp_r in [1, 2) skips expansion.
    if int(exp_r) == 1:
        # No expansion: the depthwise stage operates directly on the input width.
        exp_w = in_w
    else:
        self.expand = talk_conv_1x1(in_w, exp_w)
        self.expand_bn = nn.BatchNorm2d(
            exp_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM
        )
        self.expand_swish = act_fun()

    # Depthwise: kxk dwise, BN, activation
    self.dwise = nn.Conv2d(
        exp_w, exp_w,
        kernel_size=kernel, stride=stride, groups=exp_w, bias=False,
        # "Same" padding for any odd kernel: 3 -> 1, 5 -> 2, 7 -> 3, ...
        # Needed so the skip connection below sees matching spatial sizes.
        padding=(kernel - 1) // 2,
    )
    self.dwise_bn = nn.BatchNorm2d(
        exp_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM
    )
    self.dwise_swish = act_fun()

    # SE: x * F_ex(x)
    if cfg.EFFICIENT_NET.SE_ENABLED:
        # SE width is derived from the block input width, not the expanded one.
        se_w = int(in_w * se_r)
        self.se = SE(exp_w, se_w, act_fun)

    # Linear projection: 1x1, BN
    self.lin_proj = talk_conv_1x1(exp_w, out_w)
    self.lin_proj_bn = nn.BatchNorm2d(
        out_w, eps=cfg.BN.EPS, momentum=cfg.BN.MOM
    )
    # Nonlinear projection: add an activation after the projection BN
    if not cfg.EFFICIENT_NET.LIN_PROJ:
        self.lin_proj_swish = act_fun()

    # Skip connections on blocks w/ same in and out shapes (MN-V2, Fig. 4)
    self.has_skip = (stride == 1) and (in_w == out_w)