in pycls/models/resnet.py [0:0]
def _construct_class(self, dim_in, dim_out, stride):
    """Build the submodules of a two-stage separable transform.

    Each stage is a depthwise 3x3 conv followed by a 1x1 ``TalkConv2d`` and
    a BatchNorm; stage "a" additionally builds a ReLU. Only stage "a" applies
    ``stride`` and changes the channel count (``dim_in`` -> ``dim_out``).
    The order in which the layers are applied is decided by ``forward``,
    which is not part of this method.

    Args:
        dim_in: number of input channels to stage "a".
        dim_out: number of channels produced by both stages.
        stride: stride of the stage "a" depthwise 3x3 conv.
    """

    def depthwise_3x3(channels, conv_stride):
        # groups == in == out channels: each filter sees exactly one channel.
        return nn.Conv2d(
            channels, channels, kernel_size=3,
            stride=conv_stride, padding=1, bias=False, groups=channels,
        )

    def talk_1x1(c_in, c_out):
        # TalkConv2d is project-defined; its graph options come from
        # cfg.RGRAPH, which is re-read on every call (as the inline calls did).
        return TalkConv2d(
            c_in, c_out, cfg.RGRAPH.GROUP_NUM, kernel_size=1,
            stride=1, padding=0, bias=False,
            message_type=cfg.RGRAPH.MESSAGE_TYPE,
            directed=cfg.RGRAPH.DIRECTED,
            agg=cfg.RGRAPH.AGG_FUNC,
            sparsity=cfg.RGRAPH.SPARSITY,
            p=cfg.RGRAPH.P,
            talk_mode=cfg.RGRAPH.TALK_MODE,
            seed=self.seed,
        )

    def batch_norm(channels):
        return nn.BatchNorm2d(channels, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)

    # nn.Module registers children in attribute-assignment order. Keep the
    # order below unchanged in case code outside this method iterates the
    # children (NOTE(review): confirm against forward()).

    # Stage a: depthwise 3x3 (strided) -> 1x1 TalkConv2d -> BN, plus a ReLU.
    self.a_3x3 = depthwise_3x3(dim_in, stride)
    self.a_1x1 = talk_1x1(dim_in, dim_out)
    self.a_1x1_bn = batch_norm(dim_out)
    self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    # Stage b: depthwise 3x3 (stride 1) -> 1x1 TalkConv2d -> BN.
    self.b_3x3 = depthwise_3x3(dim_out, 1)
    self.b_1x1 = talk_1x1(dim_out, dim_out)
    self.b_1x1_bn = batch_norm(dim_out)
    # Marker flag consumed elsewhere; presumably used by weight init to treat
    # the block's last BN specially (e.g. zero-init gamma) — confirm.
    self.b_1x1_bn.final_bn = True