# vissl/models/trunks/regnet_fsdp.py
# Imports below are the subset this function needs (assumed to come from
# classy_vision / vissl, as in the rest of the file); RegnetBlocksFactory
# is defined earlier in this same file.
from collections import OrderedDict
from typing import List, Tuple

from classy_vision.models.anynet import (
    ActivationType,
    AnyNetParams,
    BlockType,
    StemType,
)
from classy_vision.models.regnet import RegNetParams
from torch import nn

from vissl.models.model_helpers import Flatten


def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (
            trunk_config["name"] == "anynet"
        ), "Please use AnyNetParams or specify RegNetParams dictionary"

    if "name" in trunk_config and trunk_config["name"] == "anynet":
        params = AnyNetParams(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[
                trunk_config.get("block_type", "res_bottleneck_block").upper()
            ],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
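        # Illustrative config sketch (hypothetical values, not from the source):
        # the "anynet" path expects per-stage lists of equal length, one entry
        # per stage, e.g. in the YAML:
        #   TRUNK:
        #     REGNET:
        #       name: anynet
        #       depths: [2, 4, 6, 2]
        #       widths: [96, 192, 432, 1008]
        #       group_widths: [24, 24, 24, 24]
        #       bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
        #       strides: [2, 2, 2, 2]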
    else:
        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier", 1.0),
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[
                trunk_config.get("block_type", "res_bottleneck_block").upper()
            ],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
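    # Illustrative config sketch (hypothetical values, not from the source):
    # the RegNetParams path only needs the five generating parameters; the
    # stages are derived from them, e.g.:
    #   TRUNK:
    #     REGNET:
    #       depth: 16
    #       w_0: 32
    #       w_a: 24.0
    #       w_m: 2.5
    #       group_width: 8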
    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It may
    # seem harmless, but autograd then computes gradients in a different order.
    # A different ordering can cause a deterministic OOM, even when peak memory
    # would otherwise be only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to dump the total module params:
    # you need to diff the module string representations.
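    # For illustration only, the problematic pattern warned about above would
    # look like this inside an nn.Module wrapper (hypothetical code):
    #   self.stem = stem                  # extra reference retained
    #   self.trunk_output = trunk_output  # extra reference retained
    # Only the returned feature_blocks should hold on to these modules.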
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []

    for i, (width_out, stride, depth, group_width, bottleneck_multiplier) in enumerate(
        params.get_expanded_params()
    ):
        # Stage indices start from 1
        stage_index = i + 1

        # Identify where the block groups start and end, and whether they should
        # be surrounded by activation checkpoints. A block group is a group of
        # blocks that is surrounded by an FSDP wrapper and optionally by an
        # activation checkpoint wrapper (a worked example follows the delimiter
        # handling below).
        with_checkpointing = (
            model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING
        )
        all_group_delimiters = trunk_config.get("stage_checkpoints", [])
        all_group_checkpoint = trunk_config.get("stage_checkpointing", [])
        group_delimiters = (
            all_group_delimiters[i] if len(all_group_delimiters) > i else []
        )
        group_checkpoint = (
            all_group_checkpoint[i] if len(all_group_checkpoint) > i else []
        )
        if not group_checkpoint:
            group_checkpoint = [with_checkpointing] * len(group_delimiters)
        assert len(group_delimiters) == len(group_checkpoint)
        assert (
            sorted(group_delimiters) == group_delimiters
        ), "Checkpoint boundaries should be sorted"

        if not group_delimiters:
            # No delimiters means one group but no activation checkpointing
            # for this group (even if USE_ACTIVATION_CHECKPOINTING is set)
            group_delimiters.append(depth)
            group_checkpoint.append(False)
        elif group_delimiters[-1] != depth:
            # Complete the missing checkpoint at the end (the user can give only
            # the intermediate checkpoints to avoid repetition)
            group_delimiters.append(depth)
            group_checkpoint.append(with_checkpointing)
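        # Worked example (hypothetical values): for a stage of depth 8 with
        # trunk_config["stage_checkpoints"] = [[2, 5]], no "stage_checkpointing"
        # entry, and activation checkpointing enabled, group_delimiters ends up
        # as [2, 5, 8] and group_checkpoint as [True, True, True]: three
        # FSDP-wrapped block groups covering blocks [0:2), [2:5) and [5:8).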
        # Create the stage from the description of the blocks and the size of
        # the block groups that compose this stage, then add it to the trunk
        new_stage = factory.create_any_stage(
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=stage_index,
            group_delimiters=group_delimiters,
            group_checkpoint=group_checkpoint,
        )
        blocks.append((f"block{stage_index}", new_stage))
        trunk_depth += blocks[-1][1].stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))
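    # Illustration: for a 4-stage model, trunk_output is equivalent to
    #   nn.Sequential(OrderedDict([("block1", stage_1), ..., ("block4", stage_4)]))
    # so each stage is addressable by name, e.g. trunk_output.block1.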
    ################################################################################

    # Now map the model to the structure we want to expose for SSL tasks.
    # The upstream RegNet model is made of:
    # - `stem`
    # - n x blocks in trunk_output, named `block1`, `block2`, ...
    # We are only interested in the stem and the successive blocks;
    # everything else is deliberately not picked up
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))

    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))
    return nn.ModuleDict(feature_blocks), trunk_depth
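
# Usage sketch (hypothetical, for illustration only): given an AttrDict-style
# `model_config` carrying the fields read above and a block factory from this
# file, the function returns the named feature blocks and the trunk depth:
#
#   feature_blocks, trunk_depth = create_regnet_feature_blocks(
#       factory=factory, model_config=model_config
#   )
#   # feature_blocks keys: conv1, res2, res3, ..., avgpool, flatten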