def convert_convnextv2_checkpoint()

in src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py [0:0]


# Common prefix shared by every official ConvNeXt V2 checkpoint URL listed below.
_CHECKPOINT_URL_PREFIX = "https://dl.fbaipublicfiles.com/convnext/convnextv2/"

# Reference values for the first five logits of each supported checkpoint, keyed by full URL.
# note: the logits below were obtained without center cropping
_EXPECTED_LOGITS = {
    _CHECKPOINT_URL_PREFIX + relative_path: values
    for relative_path, values in {
        "im1k/convnextv2_atto_1k_224_ema.pt": [-0.3930, 0.1747, -0.5246, 0.4177, 0.4295],
        "im1k/convnextv2_femto_1k_224_ema.pt": [-0.1727, -0.5341, -0.7818, -0.4745, -0.6566],
        "im1k/convnextv2_pico_1k_224_ema.pt": [-0.0333, 0.1563, -0.9137, 0.1054, 0.0381],
        "im1k/convnextv2_nano_1k_224_ema.pt": [-0.1744, -0.1555, -0.0713, 0.0950, -0.1431],
        "im1k/convnextv2_tiny_1k_224_ema.pt": [0.9996, 0.1966, -0.4386, -0.3472, 0.6661],
        "im1k/convnextv2_base_1k_224_ema.pt": [-0.2553, -0.6708, -0.1359, 0.2518, -0.2488],
        "im1k/convnextv2_large_1k_224_ema.pt": [-0.0673, -0.5627, -0.3753, -0.2722, 0.0178],
        "im1k/convnextv2_huge_1k_224_ema.pt": [-0.6377, -0.7458, -0.2150, 0.1184, -0.0597],
        "im22k/convnextv2_nano_22k_224_ema.pt": [1.0799, 0.2322, -0.8860, 1.0219, 0.6231],
        "im22k/convnextv2_nano_22k_384_ema.pt": [0.3766, 0.4917, -1.1426, 0.9942, 0.6024],
        "im22k/convnextv2_tiny_22k_224_ema.pt": [0.4220, -0.6919, -0.4317, -0.2881, -0.6609],
        "im22k/convnextv2_tiny_22k_384_ema.pt": [0.1082, -0.8286, -0.5095, 0.4681, -0.8085],
        "im22k/convnextv2_base_22k_224_ema.pt": [-0.2419, -0.6221, 0.2176, -0.0980, -0.7527],
        "im22k/convnextv2_base_22k_384_ema.pt": [0.0391, -0.4371, 0.3786, 0.1251, -0.2784],
        "im22k/convnextv2_large_22k_224_ema.pt": [-0.0504, 0.5636, -0.1729, -0.6507, -0.3949],
        "im22k/convnextv2_large_22k_384_ema.pt": [0.3560, 0.9486, 0.3149, -0.2667, -0.5138],
        "im22k/convnextv2_huge_22k_384_ema.pt": [-0.2469, -0.4550, -0.5853, -0.0810, 0.0309],
        "im22k/convnextv2_huge_22k_512_ema.pt": [-0.3090, 0.0802, -0.0682, -0.1979, -0.2826],
    }.items()
}

# Substrings looked up in the checkpoint URL to build the Hub repository name.
_MODEL_SIZES = ("atto", "femto", "pico", "nano", "tiny", "base", "large", "huge")
_IMAGE_RESOLUTIONS = ("224", "384", "512")


def _hub_model_name(checkpoint_url):
    """Build the Hub repository name ("convnextv2-<size>-<dataset>-<resolution>") from the checkpoint URL."""
    name_parts = ["convnextv2"]

    # Each supported URL names exactly one size, so taking the first match is sufficient.
    size = next((candidate for candidate in _MODEL_SIZES if candidate in checkpoint_url), None)
    if size is not None:
        name_parts.append(size)

    has_22k = "22k" in checkpoint_url
    has_1k = "1k" in checkpoint_url
    if has_22k and has_1k:
        name_parts.append("22k-1k")
    elif has_22k:
        name_parts.append("22k")
    elif has_1k:
        name_parts.append("1k")

    resolution = next((candidate for candidate in _IMAGE_RESOLUTIONS if candidate in checkpoint_url), None)
    if resolution is not None:
        name_parts.append(resolution)

    return "-".join(name_parts)


def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
    """
    Copy/paste/tweak model's weights to our ConvNeXTV2 structure.

    The original checkpoint is downloaded, its parameter names are mapped onto the
    `ConvNextV2ForImageClassification` layout, and the converted model is checked against reference logits
    before it is optionally saved locally and/or pushed to the hub.

    Args:
        checkpoint_url: URL of the original checkpoint; must be one of the keys of `_EXPECTED_LOGITS`.
        pytorch_dump_folder_path: Folder the model and preprocessor are written to when `save_model` is truthy.
            It is created (including missing parent folders) if it does not exist.
        save_model: Whether to save the converted model and preprocessor to `pytorch_dump_folder_path`.
        push_to_hub: Whether to push the converted model and preprocessor to the hub.

    Raises:
        ValueError: If `checkpoint_url` has no reference logits, or if the converted model's logits or output
            shape do not match the expected values.
    """
    print("Downloading original model from checkpoint...")
    # define ConvNeXTV2 configuration based on URL
    config, expected_shape = get_convnextv2_config(checkpoint_url)
    # load original state_dict from URL
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]

    print("Converting model parameters...")
    # rename keys and add prefix to all keys except classifier head; iterate over a snapshot of the keys
    # because the dict is mutated in place
    for old_key in list(state_dict):
        weights = state_dict.pop(old_key)
        new_key = rename_key(old_key)
        if not new_key.startswith("classifier"):
            new_key = "convnextv2." + new_key
        state_dict[new_key] = weights

    # load HuggingFace model
    model = ConvNextV2ForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image, prepared by ConvNextImageProcessor
    preprocessor = convert_preprocessor(checkpoint_url)
    inputs = preprocessor(images=prepare_img(), return_tensors="pt")
    # gradients are not needed for a verification forward pass
    with torch.no_grad():
        logits = model(**inputs).logits

    reference_values = _EXPECTED_LOGITS.get(checkpoint_url)
    if reference_values is None:
        raise ValueError(f"Unknown URL: {checkpoint_url}")
    expected_logits = torch.tensor(reference_values)

    # explicit raises instead of `assert`, so the verification still runs under `python -O`
    if not torch.allclose(logits[0, :5], expected_logits, atol=1e-3):
        raise ValueError(
            f"Converted model logits {logits[0, :5].tolist()} do not match the expected logits "
            f"{expected_logits.tolist()}"
        )
    if logits.shape != expected_shape:
        raise ValueError(f"Converted model logits have shape {tuple(logits.shape)}, expected {expected_shape}")
    print("Model outputs match the original results!")

    if save_model:
        print("Saving model to local...")
        # Create folder to save model (including any missing parent folders)
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)

        model.save_pretrained(pytorch_dump_folder_path)
        preprocessor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model_name = _hub_model_name(checkpoint_url)
        print(f"Pushing {model_name} to the hub...")
        model.push_to_hub(model_name)
        preprocessor.push_to_hub(model_name)