def convert_weights_and_push()

in src/transformers/models/regnet/convert_regnet_to_pytorch.py [0:0]


def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True):
    """Convert original RegNet checkpoints to Hugging Face RegNet models, optionally pushing them.

    Every supported model name is mapped to a ``RegNetConfig``. Most entries carry the ImageNet-1k
    label mapping downloaded from the ``huggingface/label-files`` dataset repo; the pretrained SEER
    backbones (except ``regnet-y-10b-seer``) use a plain ``RegNetConfig`` without labels. SEER
    weights are loaded from checkpoints in the VISSL model zoo; for every other name, the source
    model is whatever ``NameToFromModelFuncMap`` returns. The actual conversion for each model is
    delegated to ``convert_weight_and_push``.

    Args:
        save_directory: Forwarded to ``convert_weight_and_push``; SEER checkpoints are also
            downloaded into this directory.
        model_name: Name of a single model to convert (a key of ``names_to_config`` below). When
            falsy (``None`` or ``""``), every model in ``names_to_config`` is converted in order.
        push_to_hub: Forwarded to ``convert_weight_and_push``.

    Returns:
        A ``(config, expected_shape)`` tuple: the config of the last model converted, and
        ``(1, num_labels)`` with ``num_labels == 1000`` (presumably the expected classifier
        output shape).

    Raises:
        KeyError: If ``model_name`` is given but is not a key of ``names_to_config``.
    """
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    expected_shape = (1, num_labels)

    repo_id = "huggingface/label-files"
    id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
    # JSON object keys are always strings; the config expects integer class indices.
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    names_to_config = {
        "regnet-x-002": ImageNetPreTrainedConfig(
            depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8, layer_type="x"
        ),
        "regnet-x-004": ImageNetPreTrainedConfig(
            depths=[1, 2, 7, 12], hidden_sizes=[32, 64, 160, 384], groups_width=16, layer_type="x"
        ),
        "regnet-x-006": ImageNetPreTrainedConfig(
            depths=[1, 3, 5, 7], hidden_sizes=[48, 96, 240, 528], groups_width=24, layer_type="x"
        ),
        "regnet-x-008": ImageNetPreTrainedConfig(
            depths=[1, 3, 7, 5], hidden_sizes=[64, 128, 288, 672], groups_width=16, layer_type="x"
        ),
        "regnet-x-016": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 2], hidden_sizes=[72, 168, 408, 912], groups_width=24, layer_type="x"
        ),
        "regnet-x-032": ImageNetPreTrainedConfig(
            depths=[2, 6, 15, 2], hidden_sizes=[96, 192, 432, 1008], groups_width=48, layer_type="x"
        ),
        "regnet-x-040": ImageNetPreTrainedConfig(
            depths=[2, 5, 14, 2], hidden_sizes=[80, 240, 560, 1360], groups_width=40, layer_type="x"
        ),
        "regnet-x-064": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 1], hidden_sizes=[168, 392, 784, 1624], groups_width=56, layer_type="x"
        ),
        "regnet-x-080": ImageNetPreTrainedConfig(
            depths=[2, 5, 15, 1], hidden_sizes=[80, 240, 720, 1920], groups_width=120, layer_type="x"
        ),
        "regnet-x-120": ImageNetPreTrainedConfig(
            depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112, layer_type="x"
        ),
        "regnet-x-160": ImageNetPreTrainedConfig(
            depths=[2, 6, 13, 1], hidden_sizes=[256, 512, 896, 2048], groups_width=128, layer_type="x"
        ),
        "regnet-x-320": ImageNetPreTrainedConfig(
            depths=[2, 7, 13, 1], hidden_sizes=[336, 672, 1344, 2520], groups_width=168, layer_type="x"
        ),
        # y variant
        "regnet-y-002": ImageNetPreTrainedConfig(depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8),
        "regnet-y-004": ImageNetPreTrainedConfig(
            depths=[1, 3, 6, 6], hidden_sizes=[48, 104, 208, 440], groups_width=8
        ),
        "regnet-y-006": ImageNetPreTrainedConfig(
            depths=[1, 3, 7, 4], hidden_sizes=[48, 112, 256, 608], groups_width=16
        ),
        "regnet-y-008": ImageNetPreTrainedConfig(
            depths=[1, 3, 8, 2], hidden_sizes=[64, 128, 320, 768], groups_width=16
        ),
        "regnet-y-016": ImageNetPreTrainedConfig(
            depths=[2, 6, 17, 2], hidden_sizes=[48, 120, 336, 888], groups_width=24
        ),
        "regnet-y-032": ImageNetPreTrainedConfig(
            depths=[2, 5, 13, 1], hidden_sizes=[72, 216, 576, 1512], groups_width=24
        ),
        "regnet-y-040": ImageNetPreTrainedConfig(
            depths=[2, 6, 12, 2], hidden_sizes=[128, 192, 512, 1088], groups_width=64
        ),
        "regnet-y-064": ImageNetPreTrainedConfig(
            depths=[2, 7, 14, 2], hidden_sizes=[144, 288, 576, 1296], groups_width=72
        ),
        "regnet-y-080": ImageNetPreTrainedConfig(
            depths=[2, 4, 10, 1], hidden_sizes=[168, 448, 896, 2016], groups_width=56
        ),
        "regnet-y-120": ImageNetPreTrainedConfig(
            depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112
        ),
        "regnet-y-160": ImageNetPreTrainedConfig(
            depths=[2, 4, 11, 1], hidden_sizes=[224, 448, 1232, 3024], groups_width=112
        ),
        "regnet-y-320": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
        ),
        # models created by SEER -> https://huggingface.co/papers/2202.08360
        "regnet-y-320-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232),
        "regnet-y-640-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328),
        "regnet-y-1280-seer": RegNetConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
        ),
        # NOTE(review): no loader is registered below for "regnet-y-2560-seer" (nor its in1k variant),
        # so converting it relies on NameToFromModelFuncMap's handling of unknown names — confirm.
        "regnet-y-2560-seer": RegNetConfig(
            depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
        ),
        # NOTE(review): unlike the other pretrained SEER entries this one carries the ImageNet
        # label mapping — confirm this is intentional.
        "regnet-y-10b-seer": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
        # finetuned on imagenet
        "regnet-y-320-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
        ),
        "regnet-y-640-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328
        ),
        "regnet-y-1280-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
        ),
        "regnet-y-2560-seer-in1k": ImageNetPreTrainedConfig(
            depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
        ),
        "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
    }

    names_to_ours_model_map = NameToOurModelFuncMap()
    names_to_from_model_map = NameToFromModelFuncMap()

    # add seer weights logic: SEER checkpoints are fetched from the VISSL model zoo.
    def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> tuple[nn.Module, dict]:
        """Build a VISSL model, load its trunk from a ClassyVision checkpoint, and return it with the head weights."""
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        model = model_func()
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        # Only the trunk is loaded into the model; the "heads" entry is returned as-is alongside it.
        # It is read unconditionally, so a checkpoint without heads raises KeyError here.
        model.load_state_dict(model_state_dict["trunk"])
        return model.eval(), model_state_dict["heads"]

    def build_regnet_10b() -> nn.Module:
        # Unlike the 32/64/128GF models, the 10B SEER model is built from explicit RegNetParams.
        return FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        )

    # pretrained
    names_to_from_model_map["regnet-y-320-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )

    names_to_from_model_map["regnet-y-640-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )

    names_to_from_model_map["regnet-y-1280-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )

    names_to_from_model_map["regnet-y-10b-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        build_regnet_10b,
    )

    # IN1K finetuned
    names_to_from_model_map["regnet-y-320-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )

    names_to_from_model_map["regnet-y-640-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )

    names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )

    names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        build_regnet_10b,
    )

    # Resolve which models to convert up front so a single loop handles both cases and `config`
    # is always bound for the return statement (previously it was unbound when `model_name` was
    # given, raising UnboundLocalError after a successful conversion).
    if model_name:
        configs_to_convert = {model_name: names_to_config[model_name]}  # KeyError for unknown names
    else:
        configs_to_convert = names_to_config

    for name, config in configs_to_convert.items():
        convert_weight_and_push(
            name,
            names_to_from_model_map[name],
            names_to_ours_model_map[name],
            config,
            save_directory,
            push_to_hub,
        )
    return config, expected_shape