cookbooks/aws-parallelcluster-platform/files/dcv/pcluster_dcv_authenticator.py (290 lines of code) (raw):
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file.
# This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import argparse
import errno
import hashlib
import json
import logging
import os
import random
import re
import ssl
import string
# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess # nosec B404
import sys
import time
from collections import OrderedDict, namedtuple
from datetime import datetime, timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from logging.handlers import RotatingFileHandler
from pwd import getpwuid
from socketserver import ThreadingMixIn
from urllib.parse import parse_qsl, urlparse
AUTHORIZATION_FILE_DIR = "/var/spool/parallelcluster/pcluster_dcv_authenticator"
LOG_FILE_PATH = "/var/log/parallelcluster/pcluster_dcv_authenticator.log"
logger = logging.getLogger(__name__)
def retry(func, func_args, attempts=1, wait=0):
"""
Call function and re-execute it if it raises an Exception.
:param func: the function to execute.
:param func_args: the positional arguments of the function.
:param attempts: the maximum number of attempts. Default: 1.
:param wait: delay between attempts. Default: 0.
:returns: the result of the function.
"""
while attempts:
try:
return func(*func_args)
except Exception as e:
attempts -= 1
if not attempts:
raise e
logger.info("%s, retrying in %s seconds..", e, wait)
time.sleep(wait)
# We should never reach this line, but the linters say otherwise.
return None
def generate_random_token(token_length):
"""Generate CSPRNG compliant random tokens."""
allowed_chars = "".join((string.ascii_letters, string.digits, "_", "-"))
max_int = len(allowed_chars) - 1
system_random = random.SystemRandom()
return "".join(allowed_chars[system_random.randint(0, max_int)] for _ in range(token_length))
class OneTimeTokenHandler:
"""
Store in memory tokens and information associated with them.
The handler maintains a limited number of tokens in memory with a FIFO logic when the limits are reached.
"""
def __init__(self, max_number_of_tokens):
self._tokens = OrderedDict()
self._max_number_of_tokens = max_number_of_tokens
def add_token(self, token, token_info):
"""
Add token and his corresponding information in the storage.
:param token the token to store
:param token_info a tuple of values associated to the token to store
"""
while len(self._tokens) >= self._max_number_of_tokens:
# Remove the first token stored
self._tokens.popitem(last=False)
self._tokens[token] = token_info
def get_token_info(self, token):
"""Pop the token and return the related information if the token is present, else returns None."""
return self._tokens.pop(token, None)
class DCVAuthenticator(BaseHTTPRequestHandler):
"""
Simple HTTP server to handle Amazon DCV authentication process.
The authentication process to access to a DCV session is performed by the following steps:
1. Obtain a Request Token:
- an user declares himself and asks for a Request Token for a given DCV Session:
- curl -X GET -G http://localhost:<port> -d action=requestToken -d authUser=<username> -d sessionID=<ID>
- the authenticator will return a json containing requestToken and accessFile values:
- the requestToken must be used as parameter for the Session Token request
- the accessFile is used to verify the user identity in the Session Token request
2. Obtain a DCV Session Token:
- the user must create an "access file" in the AUTHORIZATION_FILE_DIR, named as the retrieved accessFile value
- the user asks for a SessionToken (the real token to access to the DCV session)
- curl -X GET -G http://localhost:<port> -d action=sessionToken -d requestToken=<tr>
- the authenticator verifies the owner of the access file, the validity of the requestToken and returns
a Session Token
- the user can use the retrieved Session Token to connect to the DCV session.
3. DCV connection:
- the Session Token must be used in the web browser to access to the DCV Session
- the DCV process, running in the same instance of the authenticator, will ask to validate the token:
- curl -k http://localhost:<port> -d sessionId=<session-id> -d authenticationToken=<token>
- the authenticator verifies the validity of the authenticationToken and permits the user to access to the session.
"""
class IncorrectRequestError(Exception):
"""Class representing an incorrect request to the DCVAuthenticator."""
pass
USER_REGEX = r"^[a-z_]([a-z0-9_-]{0,31}|[a-z0-9_-]{0,30}\$)$"
SESSION_ID_REGEX = r"^([a-zA-Z0-9_-]{0,128})$"
# A nosec comment is appended to the following line in order to disable the B105 check.
# Since the TOKEN_REGEX is not a hardcoded password
TOKEN_REGEX = r"^([a-zA-Z0-9_-]{256})$" # nosec B105
MAX_NUMBER_OF_REQUEST_TOKENS = 500
MAX_NUMBER_OF_SESSION_TOKENS = 100
REQUEST_TOKEN_EXPIRE_SECONDS = 10
SESSION_TOKEN_EXPIRE_SECONDS = 30
# Define the information associated to a specific token
RequestTokenInfo = namedtuple("RequestTokenInfo", "user dcv_session_id creation_time access_file")
SessionTokenInfo = namedtuple("SessionTokenInfo", "user dcv_session_id creation_time")
# Define two token handlers with different capacity and expiration
request_token_manager = OneTimeTokenHandler(max_number_of_tokens=MAX_NUMBER_OF_REQUEST_TOKENS)
request_token_ttl = timedelta(seconds=REQUEST_TOKEN_EXPIRE_SECONDS)
session_token_manager = OneTimeTokenHandler(max_number_of_tokens=MAX_NUMBER_OF_SESSION_TOKENS)
session_token_ttl = timedelta(seconds=SESSION_TOKEN_EXPIRE_SECONDS)
def do_GET(self): # noqa N802, pylint: disable=C0103
"""
Handle GET requests coming from the user to obtain request and session tokens.
The format of the request should be:
curl -X GET -G http://localhost:<port> -d action=requestToken -d authUser=<username> -d sessionID=<ID>
curl -X GET -G http://localhost:<port> -d action=sessionToken -d requestToken=<tr>
"""
try:
logger.info("Validating user request..")
# validate number of parameters
parameters = dict(parse_qsl(urlparse(self.path).query))
if not parameters or len(parameters) > 3:
raise DCVAuthenticator.IncorrectRequestError(
f"Incorrect number of parameters passed.\nParameters: {parameters}"
)
# evaluate action parameter
action = self._extract_parameters_values(parameters, ["action"])[0]
if action == "requestToken":
username, session_id = self._extract_parameters_values(parameters, ["authUser", "sessionID"])
result = self._get_request_token(username, session_id)
elif action == "sessionToken":
request_token = self._extract_parameters_values(parameters, ["requestToken"])[0]
result = self._get_session_token(request_token)
else:
raise DCVAuthenticator.IncorrectRequestError(f"The action specified '{action}' is not valid.")
self._set_headers(400, content="application/json")
self.wfile.write(result.encode())
except DCVAuthenticator.IncorrectRequestError as e:
logger.error(e)
self._return_bad_request(e)
def do_POST(self): # noqa N802 pylint: disable=C0103
"""
Handle POST requests, coming from Amazon DCV server.
The format of the request is the following:
curl -k http://localhost:<port> -d sessionId=<session-id> -d authenticationToken=<token>
"""
try:
length = int(self.headers["Content-Length"])
field_data = self.rfile.read(length).decode("utf-8")
parameters = dict(parse_qsl(field_data))
if len(parameters) != 3:
raise DCVAuthenticator.IncorrectRequestError(
f"Incorrect number of parameters passed.\nParameters: {parameters}"
)
session_token, session_id = self._extract_parameters_values(
parameters, ["authenticationToken", "sessionId"]
)
authorized_user = self._check_auth(session_id, session_token)
if authorized_user:
self._return_auth_ok(username=authorized_user)
else:
raise DCVAuthenticator.IncorrectRequestError("The session token is not valid")
except DCVAuthenticator.IncorrectRequestError as e:
logger.error(e)
self._return_auth_ko(e)
def log_message(self, fmt, *args):
"""Override Server log message by removing authentication actions."""
if all(auth_action not in args[0] for auth_action in ["requestToken", "sectionToken"]):
logger.info(fmt, args)
def _set_headers(self, response, content="text/xml", length=None):
self.send_response(response)
self.send_header("Content-type", content)
if length:
self.send_header("Content-Length", length)
self.end_headers()
def _return_auth_ko(self, message):
http_string = f'<auth result="no"><message>{message}</message></auth>'
self._set_headers(200, length=len(http_string))
self.wfile.write(http_string.encode())
def _return_auth_ok(self, username):
http_string = f'<auth result="yes"><username>{username}</username></auth>'
self._set_headers(200, length=len(http_string))
self.wfile.write(http_string.format(username).encode())
def _return_bad_request(self, message):
self._set_headers(200)
self.wfile.write(f"{message}\n".encode())
@staticmethod
def _extract_parameters_values(parameters, keys):
try:
return [parameters[key] for key in keys]
except KeyError:
raise DCVAuthenticator.IncorrectRequestError(f"Wrong parameters. Required parameters are {', '.join(keys)}")
@classmethod
def _check_auth(cls, session_id, session_token):
"""Check session token expiration to see if it is still valid for the given DCV session id."""
# validate session and session token
DCVAuthenticator._validate_param(session_id, DCVAuthenticator.SESSION_ID_REGEX, "sessionId")
DCVAuthenticator._validate_param(session_token, DCVAuthenticator.TOKEN_REGEX, "sessionToken")
# search for token in the internal authenticator token storage
token_info = cls.session_token_manager.get_token_info(session_token)
if (
token_info
and token_info.dcv_session_id == session_id
and datetime.utcnow() - token_info.creation_time <= cls.session_token_ttl
):
return token_info.user
return None
@classmethod
def _get_request_token(cls, user, session_id):
"""
Obtain the request token and the "access file" name required to obtain the session token.
Generate a Request token, store in memory and returns a json containing the token itself
and the name of the file the user must create in the AUTHORIZATION_FILE_DIR.
"""
logger.info("New request for Request Token from user '%s' and DCV Session Id '%s'.", user, session_id)
# validate user and session
DCVAuthenticator._validate_param(user, DCVAuthenticator.USER_REGEX, "authUser")
DCVAuthenticator._validate_param(session_id, DCVAuthenticator.SESSION_ID_REGEX, "sessionId")
DCVAuthenticator._verify_session_existence(user, session_id)
logger.info("DCV session id and user are valid.")
# create and register internally a request token to use to retrieve the session token
logger.info("Generating new Request Token and Access File..")
request_token = generate_random_token(256)
access_file = generate_sha512_hash(request_token)
cls.request_token_manager.add_token(
request_token, DCVAuthenticator.RequestTokenInfo(user, session_id, datetime.utcnow(), access_file)
)
logger.info("Request Token and Access File generated correctly.")
return json.dumps({"requestToken": request_token, "accessFile": access_file})
@classmethod
def _get_session_token(cls, request_token):
"""
Obtain the session token to connect to the DCV session.
Generate a Session token, store in memory and returns a json containing the token itself.
"""
logger.info("New request for Session Token.")
DCVAuthenticator._validate_param(request_token, DCVAuthenticator.TOKEN_REGEX, "requestToken")
# retrieve request token information to validate it
logger.info("Validating Request Token..")
token_info = cls.request_token_manager.get_token_info(request_token)
if not token_info:
raise DCVAuthenticator.IncorrectRequestError("The requestToken parameter is not valid")
user = token_info.user
session_id = token_info.dcv_session_id
access_file = token_info.access_file
logger.info("Request Token is valid.")
# verify token expiration
logger.info("Verifying Request Token..")
if datetime.utcnow() - token_info.creation_time > cls.request_token_ttl:
raise DCVAuthenticator.IncorrectRequestError("The requestToken is not valid anymore")
logger.info("Request Token is valid.")
# verify user by checking if the access_file is created by the user asking the session token
logger.info("Verifying Access File..")
try:
access_file_path = f"{AUTHORIZATION_FILE_DIR}/{access_file}"
file_details = os.stat(access_file_path)
if getpwuid(file_details.st_uid).pw_name != user:
raise DCVAuthenticator.IncorrectRequestError("The user is not the one that created the access file")
if datetime.utcnow() - datetime.utcfromtimestamp(file_details.st_mtime) > cls.request_token_ttl:
raise DCVAuthenticator.IncorrectRequestError("The access file has expired")
logger.info("Access File is valid. User identified correctly.")
os.remove(access_file_path)
logger.info("Access File removed correctly.")
except OSError:
raise DCVAuthenticator.IncorrectRequestError("The Access File does not exist")
# create and register internally a session token
logger.info("Generating new Session Token..")
DCVAuthenticator._verify_session_existence(user, session_id)
session_token = generate_random_token(256)
cls.session_token_manager.add_token(
session_token, DCVAuthenticator.SessionTokenInfo(user, session_id, datetime.utcnow())
)
logger.info("Session Token created successfully.")
return json.dumps({"sessionToken": session_token})
@staticmethod
def _validate_param(string_to_test, regex, resource_name):
if not re.match(regex, string_to_test):
raise DCVAuthenticator.IncorrectRequestError(f"The {resource_name} parameter is not valid")
@staticmethod
def _is_session_valid(user, session_id):
"""
Verify if the DCV session exists and the ownership.
# We are using ps aux to retrieve the list of sessions
# because currently DCV doesn't allow list-session to list all session even for non-root user.
# TODO change this method if DCV updates his behaviour.
"""
logger.info("Verifying Amazon DCV session validity..")
# Remove the first and the last because they are the heading and empty, respectively
# All commands and arguments in this subprocess call are built as literals
processes = subprocess.check_output(["/bin/ps", "aux"]).decode("utf-8").split("\n")[1:-1] # nosec B603
# Check the filter is empty
if not next(
filter(lambda process: DCVAuthenticator.check_dcv_process(process, user, session_id), processes), None
):
raise DCVAuthenticator.IncorrectRequestError("The given session does not exists")
logger.info("The Amazon DCV session is valid.")
@staticmethod
def _verify_session_existence(user, session_id):
retry(DCVAuthenticator._is_session_valid, func_args=[user, session_id], attempts=20, wait=1)
@staticmethod
def check_dcv_process(row, user, session_id):
"""Check if there is a dcvagent process running for the given user and for the given session_id."""
# row example:
# centos 63 0.0 0.0 4348844 3108 ?? Ss 23Jul19 2:32.46 /usr/libexec/dcv/dcvagent --mode full \
# --session-id mysession
# ubuntu 2949 0.3 0.4 860568 34328 ? Sl 20:10 0:18 /usr/lib/x86_64-linux-gnu/dcv/dcvagent --mode full \
# --session-id mysession
fields = row.split()
command_index = 10
session_name_index = 14
user_index = 0
return (
fields[command_index].endswith("/dcv/dcvagent")
and fields[user_index] == user
and fields[session_name_index] == session_id
)
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
"""Handle requests in a separate thread."""
def _run_server(port, certificate=None, key=None):
"""
Run Amazon DCV authenticator server on localhost.
The Amazon DCV authenticator server *must* run with an appropriate user.
:param port: the port in which you want to start the server
:param certificate: the certificate to use if HTTPSs
:param key: the private key to use if HTTPSs
"""
server_hostname = "localhost"
server_address = (server_hostname, port)
httpd = ThreadedHTTPServer(server_address, DCVAuthenticator)
ssl_context = ssl.SSLContext()
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
ssl_context.server_hostname = server_hostname
if certificate:
ssl_context.load_cert_chain(certfile=certificate, keyfile=key)
httpd.socket = ssl_context.wrap_socket(httpd.socket, server_side=True)
print(
f"Starting DCV external authenticator {'HTTPS' if certificate else 'HTTP'} server on port {port}, use "
f"<Ctrl-C> to stop "
)
httpd.serve_forever()
def _parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Execute the ParallelCluster DCV External Authenticator")
parser.add_argument("--port", help="The port in which you want to start the HTTP server", type=int)
parser.add_argument("--certificate", help="The certificate to use to run in HTTPS. It must be a .pem file")
parser.add_argument("--key", help="The private key of the certificate")
return parser.parse_args()
def generate_sha512_hash(*args):
"""Generate a salted sha512 of the given token."""
salt = generate_random_token(256)
hash_handler = hashlib.sha512()
for item in args, salt:
hash_handler.update(str(item).encode("utf-8"))
return hash_handler.hexdigest()
def _prepare_auth_folder():
"""Delete old authorization files."""
for access_file in os.listdir(AUTHORIZATION_FILE_DIR):
os.remove(os.path.join(AUTHORIZATION_FILE_DIR, access_file))
def fail(message):
"""
Print error message and exit(1).
:param message: message to print
"""
logger.error(message)
sys.exit(1)
def _config_logger():
"""
Define a logger for pcluster_dcv_authenticator.
:return: the logger
"""
try:
logfile = os.path.expanduser(LOG_FILE_PATH)
logdir = os.path.dirname(logfile)
os.makedirs(logdir)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(logdir):
pass
else:
print(f"Cannot create log file ({logfile}). Failed with exception: {e}")
sys.exit(1)
formatter = logging.Formatter("%(asctime)s %(levelname)s [%(module)s:%(funcName)s] %(message)s")
logfile_handler = RotatingFileHandler(logfile, maxBytes=5 * 1024 * 1024, backupCount=1)
logfile_handler.setFormatter(formatter)
dcv_authenticator_logger = logging.getLogger("pcluster_dcv_authenticator")
dcv_authenticator_logger.addHandler(logfile_handler)
dcv_authenticator_logger.setLevel("INFO")
return dcv_authenticator_logger
def main():
global logger # pylint: disable=C0103,W0603
logger = _config_logger()
try:
logger.info("Starting Amazon DCV authenticator server")
args = _parse_args()
_prepare_auth_folder()
_run_server(port=args.port if args.port else 8444, certificate=args.certificate, key=args.key)
except KeyboardInterrupt:
logger.info("Closing Amazon DCV authenticator server")
except Exception as e:
fail(f"Unexpected error of type {type(e).__name__}: {e}")
if __name__ == "__main__":
main()