dialogflow-cx/vpc-sc-demo/backend/get_token.py (162 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to get a stored token from the VPC-SC Demo Auth Server."""
import base64
import collections
import io
import json
import logging
import uuid
import zipfile
import flask
import requests
from Crypto import Random
from Crypto.Cipher import AES, PKCS1_OAEP
from Crypto.PublicKey import RSA
from google.auth.transport import requests as reqs
from google.oauth2 import id_token
from session_blueprint import AUTH_SERVICE_HOSTNAME
logger = logging.getLogger(__name__)
PRIVATE_PEM_FILENAME = "private_key.pem"
class LruCache:  # pylint: disable=too-few-public-methods
    """Callable wrapper that memoizes ``func`` with least-recently-used eviction.

    Results are stored in ``self.cache`` (an ``OrderedDict``) keyed by the
    tuple of positional arguments, ordered from least to most recently used.
    """

    def __init__(self, func, max_size=128):
        """Wrap ``func``, remembering at most ``max_size`` results."""
        self.max_size = max_size
        self.func = func
        self.cache = collections.OrderedDict()

    def __call__(self, *args):
        """Return the memoized result for ``args``, computing it on a miss."""
        entries = self.cache
        if args not in entries:
            # Miss: compute, remember, then drop the oldest entry on overflow.
            value = self.func(*args)
            entries[args] = value
            if len(entries) > self.max_size:
                entries.popitem(last=False)
            return value
        # Hit: mark the entry as the most recently used one.
        entries.move_to_end(args)
        return entries[args]
class AESCipher:
    """Organizes AES encryption methods.

    Messages are encrypted with AES in CBC mode. ``encrypt`` prepends a fresh
    random IV to the ciphertext and base64-encodes the result; ``decrypt``
    reverses those steps and returns the unpadded plaintext bytes.
    """

    def __init__(self, key=None, block_size=16):
        """Store the key and the padding block size.

        Args:
            key: AES key as bytes. When ``None``, the hex form of a random
                UUID4 (32 ASCII bytes) is used.
            block_size: Padding granularity used by ``pad``.
        """
        self.key = uuid.uuid4().hex.encode() if key is None else key
        self.block_size = block_size

    def pad(self, cstr):
        """Pad message to a multiple of ``block_size``.

        Appends N copies of the value N, where N is the number of units
        needed to reach the next block boundary (a full block is added when
        the input is already aligned).

        Args:
            cstr: Message as ``str`` or ``bytes``; the result has the same kind.

        Returns:
            The padded message.
        """
        pad_len = self.block_size - len(cstr) % self.block_size
        if isinstance(cstr, (bytes, bytearray)):
            return bytes(cstr) + bytes([pad_len]) * pad_len
        return cstr + pad_len * chr(pad_len)

    def unpad(self, cstr):
        """Unpad padded message.

        The last element gives the number of trailing elements to strip; the
        padding content itself is not validated. Empty input is returned
        unchanged instead of failing on ``ord`` of an empty slice.
        """
        if not cstr:
            return cstr
        return cstr[: -ord(cstr[-1:])]

    def encrypt(self, raw):
        """Encrypt plaintext.

        Args:
            raw: Plaintext as ``str`` (UTF-8 encoded before use) or ``bytes``.

        Returns:
            Base64-encoded bytes of ``IV + ciphertext``.
        """
        # Encode BEFORE padding: padding must be computed on the byte length,
        # otherwise non-ASCII text yields a payload that is not block-aligned.
        data = raw.encode() if isinstance(raw, str) else raw
        padded = self.pad(data)
        init_vec = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, init_vec)
        return base64.b64encode(init_vec + cipher.encrypt(padded))

    def decrypt(self, enc):
        """Decrypt cyphertext.

        Args:
            enc: Base64-encoded ``IV + ciphertext`` as produced by ``encrypt``.

        Returns:
            The unpadded plaintext as bytes.
        """
        enc = base64.b64decode(enc)
        # The IV occupies the first AES block, mirroring ``encrypt``.
        init_vec = enc[: AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, init_vec)
        return self.unpad(cipher.decrypt(enc[AES.block_size :]))  # noqa: E203
def get_token_from_auth_server(session_id, auth_service_hostname=AUTH_SERVICE_HOSTNAME):
    """Retrieve a stored token from the VPC-SC Demo Auth Server.

    Requests ``/auth`` on the auth service for ``session_id``. The reply is
    read as a zip archive with two members: ``key`` (decrypted with the RSA
    private key in ``PRIVATE_PEM_FILENAME``) and ``session_data`` (decrypted
    with that recovered AES key and parsed as JSON).

    Args:
        session_id: Session identifier sent as the ``session_id`` query param.
        auth_service_hostname: Host (and optional port) of the auth service.

    Returns:
        ``{"auth_data": <parsed JSON>}`` on success, otherwise
        ``{"response": flask.Response}`` carrying a BLOCKED status payload.
    """

    def blocked(reason):
        """Build the BLOCKED result dict for ``reason``."""
        return {
            "response": flask.Response(
                status=200,
                response=json.dumps({"status": "BLOCKED", "reason": reason}),
            )
        }

    auth_service_auth_endpoint = f"http://{auth_service_hostname}/auth"
    params = {
        "session_id": session_id,
    }
    req = requests.get(auth_service_auth_endpoint, params=params, timeout=10)
    if req.status_code == 401:
        logger.error(
            "  auth-service %s rejected request: %s",
            auth_service_auth_endpoint,
            req.text,
        )
        return blocked("REJECTED_REQUEST")
    if req.status_code >= 400:
        # Any other error status carries no token archive; parsing its body
        # as a zip file would only fail further down.
        logger.error(
            "  auth-service %s returned status %s: %s",
            auth_service_auth_endpoint,
            req.status_code,
            req.text,
        )
        return blocked("REJECTED_REQUEST")

    with open(PRIVATE_PEM_FILENAME, "r", encoding="utf8") as file_handle:
        private_pem = file_handle.read()

    try:
        with zipfile.ZipFile(io.BytesIO(req.content)) as zip_file:
            key_bytes_stream = zip_file.read("key")
            session_data_bytes_stream = zip_file.read("session_data")
    except (zipfile.BadZipFile, KeyError) as exc:
        # BadZipFile: body is not a zip archive; KeyError: a member is missing.
        logger.error("Malformed auth-service payload: %s", exc)
        return blocked("DECRYPTION_ERROR")

    try:
        decrypt = PKCS1_OAEP.new(key=RSA.import_key(private_pem))
        decrypted_message = decrypt.decrypt(key_bytes_stream)
        aes_cipher = AESCipher(key=decrypted_message)
        return {
            "auth_data": json.loads(
                aes_cipher.decrypt(session_data_bytes_stream).decode()
            )
        }
    except ValueError as exc:
        # Covers key import/decryption failures as well as invalid UTF-8 or
        # JSON in the decrypted payload (both are ValueError subclasses).
        logger.error("Decryption Error: %s", exc)
        return blocked("DECRYPTION_ERROR")
def _blocked_response(reason):
    """Wrap a BLOCKED status payload in the result dict used by ``get_token``."""
    return {
        "response": flask.Response(
            status=200,
            response=json.dumps({"status": "BLOCKED", "reason": reason}),
        )
    }


def get_token(
    request, token_type="access_token", cache=LruCache(get_token_from_auth_server)
):
    """Get a stored token from the VPC-SC Demo Auth Server, or from local cache.

    Args:
        request: Incoming request; its ``session_id`` cookie selects the session.
        token_type: One of ``"access_token"``, ``"id_token"`` or ``"email"``.
        cache: Callable mapping a session id to the auth-server result. The
            default instance is created once at definition time, so it is
            deliberately shared by every call that does not pass its own.

    Returns:
        ``{token_type: value}`` on success, otherwise
        ``{"response": flask.Response}`` carrying a BLOCKED status payload.
    """
    session_id = request.cookies.get("session_id")
    if not session_id:
        logger.info("get_token request did not have a session_id")
        return _blocked_response("BAD_SESSION_ID")

    response = cache(session_id)
    if "response" in response:
        # Do not keep failures cached. LruCache keys its entries by the tuple
        # of positional arguments, so the key is (session_id,), not session_id.
        entries = getattr(cache, "cache", None)
        if entries is not None:
            entries.pop((session_id,), None)
        return response

    auth_data = response["auth_data"]
    try:
        info = id_token.verify_oauth2_token(auth_data["id_token"], reqs.Request())
    except ValueError as exc:
        if "Token expired" in str(exc):
            logger.info("  auth-service token expired")
            return _blocked_response("TOKEN_EXPIRED")
        # The exception text may embed the token itself, so it is not logged.
        logger.info("  oauth error: id_token verification failed")
        return _blocked_response("UNKNOWN")

    # A missing claim is treated the same as an unverified email.
    if not info.get("email_verified"):
        logger.info("  oauth error: email not verified")
        return _blocked_response("BAD_EMAIL")

    if token_type in ("access_token", "id_token", "email"):
        return {token_type: auth_data[token_type]}

    message = (
        f'  Requested token_type "{token_type}" not one of '
        '["access_token","id_token","email"]'
    )
    logger.info(message)
    return _blocked_response(message.lstrip())