in pycls/models/resnet.py [0:0]
def _construct_class(self, dim_in, dim_out, stride, dim_inner, num_gs, seed):
    """Build a 1x1 -> 3x3 -> 1x1 bottleneck out of TalkConv2d layers.

    Registers these submodules on ``self``:
      * ``a``, ``a_bn``, ``a_relu``: 1x1 conv ``dim_in -> dim_inner``, BN, ReLU.
      * ``b``, ``b_bn``, ``b_relu``: 3x3 conv ``dim_inner -> dim_inner`` (padding 1,
        carries ``stride``), BN, ReLU.
      * ``c``, ``c_bn``: 1x1 conv ``dim_inner -> dim_out``, BN (no ReLU).

    ``c_bn`` is tagged with ``final_bn = True``; presumably an initializer
    elsewhere treats it specially (e.g. zero-init gamma) -- confirm.

    Args:
        dim_in: Number of input channels of ``a``.
        dim_out: Number of output channels of ``c``.
        stride: Spatial stride; applied on the 3x3 conv ``b``.
        dim_inner: Bottleneck width (output channels of ``a`` and ``b``).
        num_gs: Accepted for interface compatibility; not used here.
        seed: Seed forwarded to every TalkConv2d. If None, ``self.seed`` is
            used when present (earlier versions always read ``self.seed``).
    """
    # Earlier versions ignored the ``seed`` argument and read ``self.seed``
    # unconditionally; honor the argument, keeping ``self.seed`` as fallback.
    if seed is None:
        seed = getattr(self, "seed", None)

    # MSRA places stride=2 on the first 1x1; TH/C2 places it on the 3x3.
    # Only the TH/C2 layout is used (cfg.RESNET.STRIDE_1X1 is not consulted).
    str1x1, str3x3 = 1, stride

    # Relational-graph options shared by all three TalkConv2d layers.
    group_num = cfg.RGRAPH.GROUP_NUM
    talk_kwargs = dict(
        message_type=cfg.RGRAPH.MESSAGE_TYPE,
        directed=cfg.RGRAPH.DIRECTED,
        agg=cfg.RGRAPH.AGG_FUNC,
        sparsity=cfg.RGRAPH.SPARSITY,
        p=cfg.RGRAPH.P,
        talk_mode=cfg.RGRAPH.TALK_MODE,
        seed=seed,
    )

    # 1x1, BN, ReLU
    self.a = TalkConv2d(
        dim_in, dim_inner, group_num, kernel_size=1,
        stride=str1x1, padding=0, bias=False, **talk_kwargs
    )
    self.a_bn = nn.BatchNorm2d(dim_inner, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
    self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    # 3x3, BN, ReLU
    self.b = TalkConv2d(
        dim_inner, dim_inner, group_num, kernel_size=3,
        stride=str3x3, padding=1, bias=False, **talk_kwargs
    )
    self.b_bn = nn.BatchNorm2d(dim_inner, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
    self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    # 1x1, BN (no ReLU here)
    self.c = TalkConv2d(
        dim_inner, dim_out, group_num, kernel_size=1,
        stride=1, padding=0, bias=False, **talk_kwargs
    )
    self.c_bn = nn.BatchNorm2d(dim_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
    self.c_bn.final_bn = True