in pycls/models/resnet.py [0:0]
def _construct_class(self, dim_in, dim_out, stride, num_bs, dim_inner, num_gs):
    """Create the stage's residual blocks and register them as submodules.

    ``num_bs`` main blocks are registered under the names ``b1``, ``b2``, ...
    Each main block is followed by ``cfg.RGRAPH.ADD_1x1`` extra blocks,
    registered as ``b<i>_<j>1x1``, that map ``dim_out`` to ``dim_out`` with
    stride 1 and use the transformation function looked up under
    ``cfg.RESNET.TRANS_FUN + '1x1'``.

    Graph seeding: when ``cfg.RGRAPH.KEEP_GRAPH`` is truthy every block is
    given ``cfg.RGRAPH.SEED_GRAPH`` as its seed. Otherwise the seed starts at
    ``int(cfg.RGRAPH.SEED_GRAPH * 100)`` and is advanced by one after every
    block constructed (main and extra alike).

    Args:
        dim_in: Input width; passed only to the first main block.
        dim_out: Output width of every block; also the input width of every
            block after the first.
        stride: Stride passed to the first main block; all others get 1.
        num_bs: Number of main blocks to build.
        dim_inner: Forwarded unchanged to every ``ResBlock``.
        num_gs: Forwarded unchanged to every ``ResBlock``.
    """
    keep_graph = cfg.RGRAPH.KEEP_GRAPH
    configured_seed = cfg.RGRAPH.SEED_GRAPH
    graph_seed = configured_seed if keep_graph else int(configured_seed * 100)

    for block_no in range(1, num_bs + 1):
        # Only the first block of the stage sees the stage's stride and dim_in
        leads_stage = block_no == 1
        main_block = ResBlock(
            dim_in if leads_stage else dim_out,
            dim_out,
            stride if leads_stage else 1,
            get_trans_fun(cfg.RESNET.TRANS_FUN),
            dim_inner,
            num_gs,
            seed=graph_seed,
        )
        self.add_module('b{}'.format(block_no), main_block)
        if not keep_graph:
            graph_seed += 1

        # Optional extra blocks appended after this main block
        for extra_no in range(1, cfg.RGRAPH.ADD_1x1 + 1):
            extra_block = ResBlock(
                dim_out,
                dim_out,
                1,
                get_trans_fun(cfg.RESNET.TRANS_FUN + '1x1'),
                dim_inner,
                num_gs,
                seed=graph_seed,
            )
            self.add_module('b{}_{}1x1'.format(block_no, extra_no), extra_block)
            if not keep_graph:
                graph_seed += 1