def _construct_class()

in pycls/models/resnet.py [0:0]


    def _construct_class(self, dim_in, dim_out, stride, dim_inner, num_gs, seed):
        """Build the 1x1 -> 3x3 -> 1x1 bottleneck out of TalkConv2d layers.

        Registers, in this order:
          * ``a`` / ``a_bn`` / ``a_relu``: 1x1 TalkConv2d, dim_in -> dim_inner.
          * ``b`` / ``b_bn`` / ``b_relu``: 3x3 TalkConv2d (padding 1),
            dim_inner -> dim_inner; this conv carries ``stride``.
          * ``c`` / ``c_bn``: 1x1 TalkConv2d, dim_inner -> dim_out (no ReLU).

        Args:
            dim_in: input channels of the first convolution.
            dim_out: output channels of the last convolution.
            stride: spatial stride, applied on the 3x3 convolution only.
            dim_inner: channel width of the bottleneck.
            num_gs: not used here; the group count passed to every TalkConv2d
                comes from ``cfg.RGRAPH.GROUP_NUM``.
            seed: seed forwarded to every TalkConv2d. If None, falls back to
                ``self.seed`` when that attribute exists, otherwise None.
        """
        # Honor the explicit argument; fall back to the instance attribute so
        # callers that only set ``self.seed`` keep their behavior.
        if seed is None:
            seed = getattr(self, "seed", None)

        # Stride sits on the 3x3 conv (TH/C2 convention). The MSRA variant,
        # which strides the first 1x1 conv instead, is not selectable here.
        str1x1, str3x3 = 1, stride

        rgraph = cfg.RGRAPH

        def talk_conv(c_in, c_out, kernel_size, conv_stride, padding):
            # All three convolutions share the same relational-graph settings.
            return TalkConv2d(
                c_in, c_out, rgraph.GROUP_NUM, kernel_size=kernel_size,
                stride=conv_stride, padding=padding, bias=False,
                message_type=rgraph.MESSAGE_TYPE, directed=rgraph.DIRECTED,
                agg=rgraph.AGG_FUNC, sparsity=rgraph.SPARSITY, p=rgraph.P,
                talk_mode=rgraph.TALK_MODE, seed=seed,
            )

        # 1x1, BN, ReLU
        self.a = talk_conv(dim_in, dim_inner, 1, str1x1, 0)
        self.a_bn = nn.BatchNorm2d(dim_inner, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

        # 3x3, BN, ReLU
        self.b = talk_conv(dim_inner, dim_inner, 3, str3x3, 1)
        self.b_bn = nn.BatchNorm2d(dim_inner, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

        # 1x1, BN (no activation: the residual add follows this block)
        self.c = talk_conv(dim_inner, dim_out, 1, 1, 0)
        self.c_bn = nn.BatchNorm2d(dim_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Marker flag on the last BN of the block; presumably read by the
        # weight initializer (e.g. to zero-init its gamma) — verify at the
        # consumer.
        self.c_bn.final_bn = True