airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py (144 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import os
import time
import webbrowser
import jwt
import requests
class DeviceFlowAuthenticator:
idp_url: str
realm: str
client_id: str
interval: int
device_code: str | None
_access_token: str | None
_refresh_token: str | None
def __has_expired__(self, token: str) -> bool:
try:
decoded = jwt.decode(token, options={"verify_signature": False})
tA = datetime.datetime.now(datetime.timezone.utc).timestamp()
tB = int(decoded.get("exp", 0))
return tA >= tB
except:
return True
@property
def access_token(self) -> str:
if self._access_token and not self.__has_expired__(self._access_token):
return self._access_token
elif self._refresh_token and not self.__has_expired__(self._refresh_token):
self.refresh()
else:
self.login()
assert self._access_token
return self._access_token
@property
def refresh_token(self) -> str:
if self._refresh_token and not self.__has_expired__(self._refresh_token):
return self._refresh_token
else:
self.login()
assert self._refresh_token
return self._refresh_token
def __init__(
self,
idp_url: str,
realm: str,
client_id: str,
):
self.idp_url = idp_url
self.realm = realm
self.client_id = client_id
if not self.client_id or not self.realm or not self.idp_url:
raise ValueError(
"Missing required environment variables for client ID, realm, or auth server URL")
self.interval = 5
self.device_code = None
self._access_token = None
self._refresh_token = None
def refresh(self) -> None:
auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token"
response = requests.post(auth_device_url, data={
"client_id": self.client_id,
"grant_type": "refresh_token",
"scope": "openid",
"refresh_token": self._refresh_token
})
if response.status_code != 200:
raise Exception(f"Error in token refresh request: {response.status_code} - {response.text}")
data = response.json()
self._refresh_token = data["refresh_token"]
self._access_token = data["access_token"]
assert self._access_token is not None
assert self._refresh_token is not None
self.__persist_token__(self._refresh_token, self._access_token)
def login(self, interactive: bool = True) -> None:
auth_warning = None
try:
# [Flow A] Reuse saved token
if os.path.exists("auth.state"):
try:
# [A1] Load token from file
with open("auth.state", "r") as f:
data = json.load(f)
self._refresh_token = str(data["refresh_token"])
self._access_token = str(data["access_token"])
except:
auth_warning = "Failed to load auth.state file!"
else:
# [A2] Check if access token is valid, if so, return
if not self.__has_expired__(self._access_token):
return print("Authenticated via saved access token!")
else:
auth_warning = "Access token is invalid!"
# [A3] Check if refresh token is valid. if so, refresh
try:
if not self.__has_expired__(self._refresh_token):
self.refresh()
return print("Authenticated via saved refresh token!")
else:
auth_warning = "Refresh token is invalid!"
except Exception as e:
print(*e.args)
if auth_warning:
print(auth_warning)
# [Flow B] Request device and user code
# [B1] Initiate device auth flow
auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/auth/device"
response = requests.post(auth_device_url, data={
"client_id": self.client_id,
"scope": "openid",
})
if response.status_code != 200:
raise Exception(f"Error in device authorization request: {response.status_code} - {response.text}")
data = response.json()
self.device_code = data.get("device_code", self.device_code)
self.interval = data.get("interval", self.interval)
url = data['verification_uri_complete']
print(f"Please authenticate by visiting: {url}")
if interactive:
webbrowser.open(url)
# [B2] Poll until token is received
token_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token"
print("Waiting for authorization...")
while True:
response = requests.post(
token_url,
data={
"client_id": self.client_id,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": self.device_code,
},
)
if response.status_code == 200:
data = response.json()
self.__persist_token__(data["refresh_token"], data["access_token"])
print("Authenticated via device auth!")
return
elif response.status_code == 400 and response.json().get("error") == "authorization_pending":
time.sleep(self.interval)
else:
raise Exception(f"Authorization error: {response.status_code} - {response.text}")
except Exception as e:
print("login() failed!", e)
def logout(self) -> None:
self._access_token = None
self._refresh_token = None
def __persist_token__(self, refresh_token: str, access_token: str) -> None:
self._access_token = access_token
self._refresh_token = refresh_token
import json
with open("auth.state", "w") as f:
json.dump({"refresh_token": self._refresh_token, "access_token": self._access_token}, f)