models/imagenet/resnet_cnsn.py [240:266]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return nn.Sequential(*layers)

    def _enable_cross_norm(self):
        # Randomly pick `active_num` of the `cn_num` CrossNorm modules and
        # flag them as active for the upcoming forward pass.
        active_cn_idxs = np.random.choice(self.cn_num, self.active_num, replace=False).tolist()
        assert len(set(active_cn_idxs)) == self.active_num
        for idx in active_cn_idxs:
            self.cn_modules[idx].active = True

    def forward(self, x, aug=False):
        # See note [TorchScript super()]
        if aug:
            # Augmented pass: enable a random subset of CrossNorm modules.
            self._enable_cross_norm()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
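
Below is a minimal sketch of a CrossNorm module compatible with the `active` flag that `_enable_cross_norm` toggles above. It is an assumption for illustration, not the repository's exact implementation: the random-partner pairing and the one-shot deactivation are illustrative choices.

import torch
import torch.nn as nn


class CrossNorm(nn.Module):
    """Illustrative CrossNorm: swaps channel-wise mean/std between instances.

    The module stays dormant until `active` is set to True (as done by
    `_enable_cross_norm`); it then applies the statistic exchange once on the
    next training forward pass and deactivates itself.
    """

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.active = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.active or not self.training:
            return x
        self.active = False  # one-shot: consume the activation (assumed behavior)

        n = x.size(0)
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = (x.var(dim=(2, 3), keepdim=True) + self.eps).sqrt()

        # Pair each instance with a randomly permuted partner and re-scale the
        # normalized features with the partner's channel-wise statistics.
        perm = torch.randperm(n, device=x.device)
        x_norm = (x - mean) / std
        return x_norm * std[perm] + mean[perm]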



models/imagenet/resnet_ibn_cnsn.py [220:245]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return nn.Sequential(*layers)

    def _enable_cross_norm(self):
        # Randomly pick `active_num` of the `cn_num` CrossNorm modules and
        # flag them as active for the upcoming forward pass.
        active_cn_idxs = np.random.choice(self.cn_num, self.active_num, replace=False).tolist()
        assert len(set(active_cn_idxs)) == self.active_num
        for idx in active_cn_idxs:
            self.cn_modules[idx].active = True

    def forward(self, x, aug=False):
        if aug:
            # Augmented pass: enable a random subset of CrossNorm modules.
            self._enable_cross_norm()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
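
An illustrative training-step fragment showing how the `aug` flag in the forward methods above might be used. The names `model`, `loader`, `optimizer`, and `criterion` are assumptions, not part of the excerpt; the point is that CrossNorm augmentation runs only when `aug=True`, so evaluation with the default `aug=False` sees unmodified features.

model.train()
for images, targets in loader:
    optimizer.zero_grad()
    logits = model(images, aug=True)  # activates `active_num` random CrossNorm modules
    loss = criterion(logits, targets)
    loss.backward()
    optimizer.step()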



