contrib/utils.py (30 lines of code) (raw):
import contextlib
from typing import Generator
from unittest.mock import patch
@contextlib.contextmanager
def production_endpoint() -> Generator[None, None, None]:
    """Patch huggingface_hub to connect to the production server inside a context manager.

    Ugly way to patch all constants at once.
    TODO: refactor when https://github.com/huggingface/huggingface_hub/issues/1172 is fixed.

    While the context is active:
      - ``ENDPOINT`` is set to ``https://huggingface.co`` in every module listed
        in ``ENDPOINT_TARGETS``;
      - ``HUGGINGFACE_CO_URL_TEMPLATE`` is set to the production
        ``/{repo_id}/resolve/{revision}/{filename}`` template in every module
        listed in ``URL_TEMPLATE_TARGETS``;
      - ``huggingface_hub.hf_api.api.endpoint`` is set to ``https://huggingface.co``.

    All patches are undone when the context exits, including when the body
    raises or when one of the patches fails to apply.

    Example:
    ```py
    def test_push_to_hub():
        # Pull from production Hub
        with production_endpoint():
            model = ...from_pretrained("modelname")

        # Push to staging Hub
        model.push_to_hub()
    ```
    """
    PROD_ENDPOINT = "https://huggingface.co"
    ENDPOINT_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub._commit_api",
        "huggingface_hub.hf_api",
        "huggingface_hub.lfs",
        "huggingface_hub.commands.user",
        "huggingface_hub.utils._git_credential",
    ]

    PROD_URL_TEMPLATE = PROD_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
    URL_TEMPLATE_TARGETS = [
        "huggingface_hub.constants",
        "huggingface_hub.file_download",
    ]

    # Imported lazily so the module can be imported without huggingface_hub
    # being loaded until the context manager is actually used.
    from huggingface_hub.hf_api import api

    patchers = (
        [patch(target + ".ENDPOINT", PROD_ENDPOINT) for target in ENDPOINT_TARGETS]
        + [patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE) for target in URL_TEMPLATE_TARGETS]
        # `api.endpoint` gets the bare endpoint, like the module-level ENDPOINT
        # constants above -- not the download URL template.
        + [patch.object(api, "endpoint", PROD_ENDPOINT)]
    )

    # ExitStack undoes every patch that was successfully applied, even if a
    # later patch fails to start or the body of the `with` block raises.
    with contextlib.ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        yield