def main()

in extra_scripts/convert_vissl_to_classy_vision.py [0:0]


def main():
    """Convert a VISSL/SSL-framework ResNet checkpoint to a Classy Vision checkpoint.

    Reads the checkpoint at ``--input_model_file`` and selects the entry stored
    under ``--state_dict_key_name``. For ``classy_state_dict`` checkpoints, the
    trunk (and, with ``--include_heads``, the heads) is taken from
    ``["base_model"]["model"]``. The trunk is converted with
    ``convert_trunk_to_classy_model`` and the heads, optionally, with
    ``convert_heads_to_classy_model``. The result is saved to ``--output_model``
    as::

        {"classy_state_dict": {"base_model": {"model": {"trunk": ..., "heads": ...}}}}

    where ``"heads"`` is present only when ``--include_heads`` is given.
    """
    parser = argparse.ArgumentParser(
        description="Convert SSL framework ResNet models to Classy Vision models"
    )
    parser.add_argument(
        "--depth", type=int, default=50, help="Depth of the ResNet model to convert"
    )
    parser.add_argument(
        "--num_fc_layers",
        type=int,
        default=1,
        help="Number of linear layers in the head",
    )
    parser.add_argument(
        "--include_heads",
        default=False,
        action="store_true",
        help="Whether to convert the head as well or not",
    )
    parser.add_argument(
        "--use_bn_head",
        default=False,
        action="store_true",
        help="Whether BN is in the head or not",
    )
    parser.add_argument(
        "--use_bias_head_fc",
        default=False,
        action="store_true",
        help="Whether FC layers in head have bias param or not",
    )
    parser.add_argument(
        "--output_head_prefix",
        type=str,
        default="embedding",
        help="The [output_prefix]_fc.weight for the heads",
    )
    # Both paths are mandatory: without them the script would otherwise fail
    # deep inside torch.load / torch.save with an unhelpful error.
    parser.add_argument(
        "--input_model_file",
        type=str,
        required=True,
        help="Path to input model weights to be converted",
    )
    parser.add_argument(
        "--output_model",
        type=str,
        required=True,
        help="Path to save the converted model",
    )
    parser.add_argument(
        "--state_dict_key_name",
        type=str,
        default="classy_state_dict",
        help="Pytorch model state_dict key name",
    )
    args = parser.parse_args()

    # Heads are only reachable inside a "classy_state_dict" checkpoint. Reject
    # the incompatible combination before loading a potentially large file, and
    # use an explicit check rather than `assert` (which `python -O` strips).
    if args.include_heads and args.state_dict_key_name != "classy_state_dict":
        parser.error(
            "Can't convert heads: --include_heads requires "
            "--state_dict_key_name=classy_state_dict"
        )

    # load the input model weights
    logger.info("Loading weights...")
    # Map tensors to CPU so checkpoints saved from GPU runs can be converted on
    # machines without the same GPUs.
    checkpoint = torch.load(args.input_model_file, map_location="cpu")
    if args.state_dict_key_name not in checkpoint:
        raise KeyError(f"{args.state_dict_key_name} not found")
    state_dict = checkpoint[args.state_dict_key_name]

    state_dict_head = None
    if args.state_dict_key_name == "classy_state_dict":
        model_state = state_dict["base_model"]["model"]
        state_dict = model_state["trunk"]
        # Only touch "heads" when requested so trunk-only checkpoints convert.
        if args.include_heads:
            state_dict_head = model_state["heads"]

    converted_model = {
        "trunk": convert_trunk_to_classy_model(state_dict, args.depth)
    }
    if args.include_heads:
        converted_model["heads"] = convert_heads_to_classy_model(
            state_dict_head,
            args.output_head_prefix,
            args.num_fc_layers,
            args.use_bn_head,
            args.use_bias_head_fc,
        )
    output_state_dict = {"classy_state_dict": {"base_model": {"model": converted_model}}}

    logger.info("Saving converted weights to: %s", args.output_model)
    torch.save(output_state_dict, args.output_model)
    logger.info("Done!!")