in models/regnet.py [0:0]
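# Constructor overview: builds the stem, a trunk of AnyStage blocks (one per
# entry returned by params.get_expanded_params()), and an optional fully
# connected classification head when params.num_classes is set.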
def __init__(self, params: AnyNetParams):
    super().__init__()

    # nn.SiLU only exists in PyTorch >= 1.7; fall back to None so the check
    # below raises the explicit RuntimeError instead of an AttributeError here.
    silu = nn.SiLU() if hasattr(nn, "SiLU") else None
    activation = {
        ActivationType.RELU: nn.ReLU(params.relu_in_place),
        ActivationType.SILU: silu,
    }[params.activation]
    if activation is None:
        raise RuntimeError("SiLU activation is only supported since PyTorch 1.7")

    assert params.num_classes is None or is_pos_int(params.num_classes)
    # Build the stem for the configured stem type (3-channel image input)
    self.stem = {
        StemType.RES_STEM_CIFAR: ResStemCifar,
        StemType.RES_STEM_IN: ResStemIN,
        StemType.SIMPLE_STEM_IN: SimpleStemIN,
    }[params.stem_type](
        3,
        params.stem_width,
        params.bn_epsilon,
        params.bn_momentum,
        activation,
    )
    # Instantiate all the AnyNet blocks in the trunk
    block_fun = {
        BlockType.VANILLA_BLOCK: VanillaBlock,
        BlockType.RES_BASIC_BLOCK: ResBasicBlock,
        BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
        BlockType.RES_BOTTLENECK_LINEAR_BLOCK: ResBottleneckLinearBlock,
    }[params.block_type]
    current_width = params.stem_width
    self.trunk_depth = 0
    blocks = []
    for i, (
        width_out,
        stride,
        depth,
        group_width,
        bottleneck_multiplier,
    ) in enumerate(params.get_expanded_params()):
        blocks.append(
            (
                f"block{i+1}",
                AnyStage(
                    current_width,
                    width_out,
                    stride,
                    depth,
                    block_fun,
                    activation,
                    group_width,
                    bottleneck_multiplier,
                    params,
                    stage_index=i + 1,
                ),
            )
        )
        # Accumulate the total block count and feed this stage's output width
        # into the next stage.
        self.trunk_depth += blocks[-1][1].stage_depth
        current_width = width_out

    self.trunk_output = nn.Sequential(OrderedDict(blocks))
    # Init weights and good to go
    self.init_weights()

    # Attach a fully connected classification head only when num_classes is
    # given; otherwise the model is used as a trunk / feature extractor.
    if params.num_classes is not None:
        self.head = FullyConnectedHead(
            num_classes=params.num_classes, in_plane=current_width
        )
    else:
        self.head = None
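
# Usage sketch (illustrative only, not part of the original file). The name of
# the enclosing class and the exact AnyNetParams constructor fields are
# assumptions: the per-stage lists below are a guess at what
# get_expanded_params() expands, while num_classes is read directly by the
# __init__ above.
params = AnyNetParams(
    depths=[1, 1, 4, 7],              # blocks per stage (assumed field name)
    widths=[32, 64, 160, 384],        # output width per stage (assumed)
    strides=[2, 2, 2, 2],             # stride per stage (assumed)
    group_widths=[8, 8, 8, 8],        # grouped-conv width per stage (assumed)
    bottleneck_multipliers=[1.0, 1.0, 1.0, 1.0],
    num_classes=1000,
)
model = AnyNet(params)  # hypothetical name for the class defining this __init__
# The constructed pieces compose as stem -> trunk_output -> head (if any) in
# the model's forward pass, which is defined elsewhere in this file.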