in siammot/modelling/backbone/dla.py [0:0]
def __init__(self, levels, block, in_channels, out_channels, stride=1,
             dilation=1, cardinality=1, base_width=64,
             level_root=False, root_dim=0, root_kernel_size=1, root_residual=False,
             batch_norm=FrozenBatchNorm2d, with_dcn=False):
    """Build one (possibly recursive) DLA aggregation tree.

    Args:
        levels: Depth of the tree. ``1`` builds two ``block`` leaves plus a
            ``DlaRoot``. Larger values build two ``DlaTree`` children of depth
            ``levels - 1`` and no root at this level. Must be >= 1.
        block: Leaf constructor, called as
            ``block(in_ch, out_ch, stride, dilation=..., cardinality=...,
            base_width=..., batch_norm=..., with_dcn=...)``.
        in_channels: Channels fed to ``tree1``.
        out_channels: Channels produced by ``tree1`` and ``tree2``.
        stride: Stride passed to ``tree1``. When > 1, a
            ``MaxPool2d(stride, stride=stride)`` is also stored as
            ``self.downsample``; otherwise ``self.downsample`` is ``None``.
        dilation, cardinality, base_width: Forwarded to every ``block``.
        level_root: If True, ``in_channels`` is added to ``root_dim``.
        root_dim: Starting root input width. ``0`` means
            ``2 * out_channels``; ``in_channels`` is added on top when
            ``level_root`` is True.
        root_kernel_size, root_residual: Forwarded to ``DlaRoot`` (and to
            nested ``DlaTree`` children when ``levels > 1``).
        batch_norm: Normalization layer class used by the blocks, the root
            and the channel projection.
        with_dcn: Forwarded to every ``block``.

    Raises:
        ValueError: If ``levels < 1``. Without this check the recursive
            branch would keep decrementing ``levels`` past the ``levels == 1``
            base case until Python raises ``RecursionError``.
    """
    if levels < 1:
        raise ValueError("levels must be >= 1, got {}".format(levels))
    super(DlaTree, self).__init__()
    if root_dim == 0:
        root_dim = 2 * out_channels
    if level_root:
        root_dim += in_channels
    # Keyword arguments shared by every leaf block in this (sub)tree.
    cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width,
                 batch_norm=batch_norm, with_dcn=with_dcn)
    if levels == 1:
        # Base case: two leaf blocks plus the root node that merges them.
        # Registration order (tree1, tree2, root) matches the original code.
        self.tree1 = block(in_channels, out_channels, stride, **cargs)
        self.tree2 = block(out_channels, out_channels, 1, **cargs)
        self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual,
                            batch_norm=batch_norm)
    else:
        # Recursive case: no root at this level. Root settings are forwarded,
        # and tree2's root width grows by out_channels -- presumably to hold
        # this level's tree1 output (forward not visible here; confirm).
        cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
        self.tree1 = DlaTree(
            levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
        self.tree2 = DlaTree(
            levels - 1, block, out_channels, out_channels,
            root_dim=root_dim + out_channels, **cargs)
    self.level_root = level_root
    self.root_dim = root_dim
    self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
    self.project = None
    if in_channels != out_channels:
        # 1x1 conv + norm mapping in_channels -> out_channels; presumably used
        # to build a residual of matching width in forward (confirm).
        self.project = nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            batch_norm(out_channels),
        )
    self.levels = levels