def _rename_weights_for_resnet()

in src/MaskRCNNDetection/maskrcnn_benchmark/utils/c2_model_loading.py [0:0]


def _rename_weights_for_resnet(weights, stage_names):
    """Rename the keys of a Caffe2 (C2) weight dict and convert values to tensors.

    The keys of ``weights`` are sorted and pushed through a fixed sequence of
    renaming passes (classifier output, basic ResNet names, FPN, Mask R-CNN,
    Keypoint R-CNN, RPN). Every entry whose original key does not contain
    ``"_momentum"`` is converted with ``torch.from_numpy`` and stored under
    its renamed key; each mapping is logged at INFO level.

    Args:
        weights: Mapping from C2 blob name to a value accepted by
            ``torch.from_numpy`` (presumably NumPy arrays -- confirm against
            the caller).
        stage_names: Forwarded unchanged to ``_rename_fpn_weights``.

    Returns:
        An ``OrderedDict`` mapping renamed keys to tensors, filled in the
        sorted order of the original keys. If two original keys end up with
        the same renamed key, the later one overwrites the earlier one.
        The result is empty when no key survives the ``"_momentum"`` filter.
    """
    original_keys = sorted(weights.keys())
    # Position-aligned working copy: layer_keys[i] is always the current
    # renamed form of original_keys[i].
    layer_keys = list(original_keys)

    # for X-101, rename output to fc1000 to avoid conflicts afterwards
    pred_renames = {"pred_b": "fc1000_b", "pred_w": "fc1000_w"}
    layer_keys = [pred_renames.get(k, k) for k in layer_keys]

    # performs basic renaming: _ -> . , etc
    layer_keys = _rename_basic_resnet_weights(layer_keys)

    # FPN
    layer_keys = _rename_fpn_weights(layer_keys, stage_names)

    # Mask R-CNN
    # NOTE: these are literal substring replacements (str.replace), so the
    # brackets in ".[mask].fcn" are matched verbatim, not as a regex class.
    layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys]
    layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys]
    layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys]

    # Keypoint R-CNN
    # The longer "kps.score.lowres" pattern must run before "kps.score",
    # otherwise the shorter replacement would consume its prefix first.
    layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys]
    layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys]
    layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys]

    # Rename for our RPN structure
    layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys]

    # Assumes the helper passes above return one renamed key per input key,
    # in the same order -- the zip below relies on that alignment.
    key_map = dict(zip(original_keys, layer_keys))

    logger = logging.getLogger(__name__)
    logger.info("Remapping C2 weights")
    # Column width for the log output. default=0 keeps an empty (or
    # momentum-only) weight dict from raising ValueError in max().
    max_c2_key_size = max(
        (len(k) for k in original_keys if "_momentum" not in k), default=0
    )

    new_weights = OrderedDict()
    for k in original_keys:
        # Entries whose key contains "_momentum" are not carried over.
        if "_momentum" in k:
            continue
        w = torch.from_numpy(weights[k])
        logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k]))
        new_weights[key_map[k]] = w

    return new_weights