in src/autotrain/app/oauth.py [0:0]
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """
    Add the Hugging Face OAuth routes (login and callback handler) to the FastAPI app.

    This function performs the following tasks:
    1. Checks that the required OAuth settings are present and raises a ValueError if any is None.
    2. Registers the "huggingface" OAuth client with the client ID, client secret, scopes, and the
       OpenID configuration URL derived from the provider URL.
    3. Defines the following routes on `app`:
        - `/login/huggingface`: Redirects to the Hugging Face OAuth page.
        - `/auth`: Handles the OAuth callback and stores the OAuth info in the session.

    Args:
        app (fastapi.FastAPI): The FastAPI application instance to which the OAuth routes will be added.

    Raises:
        ValueError: If any of OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_SCOPES or OPENID_PROVIDER_URL
            is None.
    """
    # Check environment variables.
    # The checks are kept as explicit `is None` tests (rather than a loop) so that type checkers
    # can narrow the values to non-None before they are used below.
    msg = (
        "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Register OAuth server
    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
    )

    # Define OAuth routes
    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """
        Redirect the user to the Hugging Face OAuth authorization page.

        Args:
            request (fastapi.Request): The incoming HTTP request.

        Returns:
            The response produced by the OAuth client's `authorize_redirect`, using the URL of the
            `auth` route as the redirect URI.
        """
        redirect_uri = request.url_for("auth")
        redirect_uri_as_str = str(redirect_uri)
        if redirect_uri.netloc.endswith(".hf.space"):
            # On *.hf.space hosts the callback is forced to https. Only the leading scheme is
            # rewritten (count=1) so that no other part of the URL can be altered.
            redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://", 1)
        return await oauth.huggingface.authorize_redirect(request, redirect_uri_as_str)  # type: ignore

    @app.get("/auth")
    async def auth(request: fastapi.Request) -> RedirectResponse:
        """
        Handle the OAuth callback for Hugging Face authentication.

        On success, the OAuth info returned by the provider is stored in the session under the
        "oauth_info" key and the result of `_redirect_to_target(request)` is returned.

        Args:
            request (fastapi.Request): The incoming request object.

        Returns:
            RedirectResponse: A redirect to the target page on success, or back to
            `/login/huggingface` (keeping any `_target_url` query parameter) when the OAuth state
            does not match.

        Notes:
            - `MismatchingStateError` is handled here, not propagated. A state mismatch is very likely
              caused by a corrupted cookie: a bug reported in authlib makes the token grow indefinitely
              if the user tries to login repeatedly. Since cookies cannot exceed 4kb, the token gets
              truncated at some point, losing the state. The workaround is to delete the OAuth state
              entries from the session and send the user to the login page again.
            - See https://github.com/lepture/authlib/issues/622 for more details.
        """
        try:
            oauth_info = await oauth.huggingface.authorize_access_token(request)  # type: ignore
        except MismatchingStateError:
            # Corrupted/truncated session cookie (see the notes in the docstring): restart the login flow.
            login_uri = "/login/huggingface"
            if "_target_url" in request.query_params:
                login_uri += "?" + urllib.parse.urlencode(  # Keep same _target_url as before
                    {"_target_url": request.query_params["_target_url"]}
                )
            # Iterate over a snapshot of the keys because entries are removed inside the loop.
            for key in list(request.session.keys()):
                # Delete all keys that are related to the OAuth state
                if key.startswith("_state_huggingface"):
                    request.session.pop(key)
            return RedirectResponse(login_uri)

        request.session["oauth_info"] = oauth_info
        return _redirect_to_target(request)