def create_regnet_feature_blocks()

in vissl/models/trunks/regnet_fsdp.py


# Imports this excerpt relies on (paths assumed from the surrounding file;
# RegnetBlocksFactory is defined earlier in regnet_fsdp.py itself):
from collections import OrderedDict
from typing import List, Tuple

from classy_vision.models.anynet import (
    ActivationType,
    AnyNetParams,
    BlockType,
    StemType,
)
from classy_vision.models.regnet import RegNetParams
from torch import nn
from vissl.models.model_helpers import Flatten


def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (
            trunk_config["name"] == "anynet"
        ), "Please use AnyNetParams or specify RegNetParams dictionary"

    if "name" in trunk_config and trunk_config["name"] == "anynet":
        params = AnyNetParams(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[
                trunk_config.get("block_type", "res_bottleneck_block").upper()
            ],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
    else:
        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier", 1.0),
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[
                trunk_config.get("block_type", "res_bottleneck_block").upper()
            ],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
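
    # Illustrative only (made-up values): the minimal TRUNK.REGNET keys each
    # branch above reads.
    #
    #   AnyNet path:
    #     {"name": "anynet", "depths": [2, 4, 6, 2], "widths": [64, 128, 256, 512],
    #      "group_widths": [16, 16, 16, 16],
    #      "bottleneck_multipliers": [1.0, 1.0, 1.0, 1.0], "strides": [2, 2, 2, 2]}
    #   RegNet path:
    #     {"depth": 22, "w_0": 24, "w_a": 24.48, "w_m": 2.54, "group_width": 16}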

    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It
    # may seem harmless, but autograd then ends up computing gradients in a
    # different order, and a different ordering can cause a deterministic OOM,
    # even when peak memory is otherwise only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)
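
    # A minimal sketch (assumed helper, not part of this file) of the module
    # diffing mentioned above, e.g. via difflib:
    #
    #   import difflib
    #   diff = difflib.unified_diff(
    #       str(module_a).splitlines(), str(module_b).splitlines(), lineterm=""
    #   )
    #   print("\n".join(diff))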

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width, bottleneck_multiplier) in enumerate(
        params.get_expanded_params()
    ):
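        # Each expanded tuple describes one stage. Illustrative values only:
        # e.g. (width_out=128, stride=2, depth=4, group_width=16,
        # bottleneck_multiplier=1.0).
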
        # Stage indices start from 1
        stage_index = i + 1

        # Identify where the block groups start and end, and whether they
        # should be wrapped in activation checkpoints. A block group is a
        # group of blocks surrounded by a FSDP wrapper and optionally an
        # activation checkpoint wrapper.
        with_checkpointing = (
            model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING
        )
        all_group_delimiters = trunk_config.get("stage_checkpoints", [])
        all_group_checkpoint = trunk_config.get("stage_checkpointing", [])
        group_delimiters = (
            all_group_delimiters[i] if len(all_group_delimiters) > i else []
        )
        group_checkpoint = (
            all_group_checkpoint[i] if len(all_group_checkpoint) > i else []
        )
        if not group_checkpoint:
            group_checkpoint = [with_checkpointing] * len(group_delimiters)
        assert len(group_delimiters) == len(group_checkpoint)

        assert (
            sorted(group_delimiters) == group_delimiters
        ), "Checkpoint boundaries should be sorted"
        if not group_delimiters:
            # No delimiters means one group but no activation checkpointing
            # for this group (even if USE_ACTIVATION_CHECKPOINTING is set)
            group_delimiters.append(depth)
            group_checkpoint.append(False)
        elif group_delimiters[-1] != depth:
            # Complete the missing checkpoints at the end (the user can give
            # only the intermediate checkpoints to avoid repetition)
            group_delimiters.append(depth)
            group_checkpoint.append(with_checkpointing)
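
        # Worked example (illustrative): with depth == 4 and a configured
        # stage_checkpoints entry of [2], the code above completes the group
        # description to group_delimiters == [2, 4] and group_checkpoint ==
        # [with_checkpointing, with_checkpointing], i.e. two wrapped groups
        # of two blocks each.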

        # Create the stage from the block description and the sizes of the
        # block groups that compose it, then add the stage to the trunk
        new_stage = factory.create_any_stage(
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=stage_index,
            group_delimiters=group_delimiters,
            group_checkpoint=group_checkpoint,
        )
        blocks.append((f"block{stage_index}", new_stage))
        trunk_depth += blocks[-1][1].stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the model to the structure we want to expose for SSL tasks.
    # The upstream RegNet model is made of:
    # - `stem`
    # - n blocks in trunk_output, named `block1`, `block2`, ...
    # We are only interested in the stem and the successive blocks;
    # everything else is deliberately not picked up.
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))
    return nn.ModuleDict(feature_blocks), trunk_depth
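
For a four-stage trunk, the returned ModuleDict exposes the stem and stages
under the names assigned above (a sketch, assuming four stages; the actual
stage count depends on the config):

    blocks, trunk_depth = create_regnet_feature_blocks(factory, model_config)
    print(list(blocks.keys()))
    # ['conv1', 'res2', 'res3', 'res4', 'res5', 'avgpool', 'flatten']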