atr/ssh.py (358 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """SSH server module for ATR.""" import asyncio import asyncio.subprocess import datetime import logging import os import string from typing import Final, TypeVar import aiofiles import aiofiles.os import asyncssh import atr.config as config import atr.db as db import atr.db.models as models import atr.revision as revision import atr.user as user import atr.util as util _LOGGER: Final = logging.getLogger(__name__) _CONFIG: Final = config.get() T = TypeVar("T") class SSHServer(asyncssh.SSHServer): """Simple SSH server that handles connections.""" def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: """Called when a connection is established.""" # Store connection for use in begin_auth self._conn = conn peer_addr = conn.get_extra_info("peername")[0] _LOGGER.info(f"SSH connection received from {peer_addr}") def connection_lost(self, exc: Exception | None) -> None: """Called when a connection is lost or closed.""" if exc: _LOGGER.error(f"SSH connection error: {exc}") else: _LOGGER.info("SSH connection closed") async def begin_auth(self, username: str) -> bool: """Begin authentication for the specified user.""" _LOGGER.info(f"Beginning auth for user {username}") try: # Load SSH keys for this user from the database async with db.session() as data: user_keys = await data.ssh_key(asf_uid=username).all() if not user_keys: _LOGGER.warning(f"No SSH keys found for user: {username}") # Still require authentication, but it will fail return True # Create an authorized_keys file as a string auth_keys_lines = [] for user_key in user_keys: auth_keys_lines.append(user_key.key) auth_keys_data = "\n".join(auth_keys_lines) _LOGGER.info(f"Loaded {len(user_keys)} SSH keys for user {username}") # Set the authorized keys in the connection try: authorized_keys = asyncssh.import_authorized_keys(auth_keys_data) self._conn.set_authorized_keys(authorized_keys) _LOGGER.info(f"Successfully set authorized keys for {username}") except Exception as e: _LOGGER.error(f"Error setting authorized keys: {e}") except Exception as e: _LOGGER.error(f"Database error loading SSH keys: {e}") # Always require authentication return True def public_key_auth_supported(self) -> bool: """Indicate whether public key authentication is supported.""" return True async def server_start() -> asyncssh.SSHAcceptor: """Start the SSH server.""" # TODO: Where do we actually do this? # await aiofiles.os.makedirs(_CONFIG.STATE_DIR, exist_ok=True) # Generate temporary host key if it doesn't exist key_path = os.path.join(_CONFIG.STATE_DIR, "ssh_host_key") if not await aiofiles.os.path.exists(key_path): private_key = asyncssh.generate_private_key("ssh-rsa") private_key.write_private_key(key_path) _LOGGER.info(f"Generated SSH host key at {key_path}") server = await asyncssh.create_server( SSHServer, server_host_keys=[key_path], process_factory=_step_01_handle_client, host=_CONFIG.SSH_HOST, port=_CONFIG.SSH_PORT, encoding=None, ) _LOGGER.info(f"SSH server started on {_CONFIG.SSH_HOST}:{_CONFIG.SSH_PORT}") return server async def server_stop(server: asyncssh.SSHAcceptor) -> None: """Stop the SSH server.""" server.close() await server.wait_closed() _LOGGER.info("SSH server stopped") def _fail(process: asyncssh.SSHServerProcess, message: str, return_value: T) -> T: _LOGGER.error(message) # Ensure message is encoded before writing to stderr encoded_message = f"ATR SSH error: {message}\n".encode() try: process.stderr.write(encoded_message) except BrokenPipeError: _LOGGER.warning("Failed to write error to client stderr: Broken pipe") except Exception as e: _LOGGER.exception(f"Error writing to client stderr: {e}") process.exit(1) return return_value async def _step_01_handle_client(process: asyncssh.SSHServerProcess) -> None: """Process client command, validating and dispatching to read or write handlers.""" asf_uid = process.get_extra_info("username") _LOGGER.info(f"Handling command for authenticated user: {asf_uid}") if not process.command: return _fail(process, "No command specified", None) _LOGGER.info(f"Command received: {process.command}") # TODO: Use shlex.split or similar if commands can contain quoted arguments argv = process.command.split() ############################################## ### Calls _step_02_command_simple_validate ### ############################################## simple_validation_error, path_index, is_read_request = _step_02_command_simple_validate(argv) if simple_validation_error: return _fail(process, f"{simple_validation_error}\nCommand: {process.command}", None) ####################################### ### Calls _step_04_command_validate ### ####################################### validation_results = await _step_04_command_validate(process, argv, path_index, is_read_request) if not validation_results: return # Unpack results # The release object is only present for read requests project_name, version_name, release_obj = validation_results if is_read_request: if release_obj is None: # This should not happen if the validation logic is correct return _fail(process, "Internal error: Release object missing for read request after validation", None) _LOGGER.info(f"Processing READ request for {project_name}-{version_name}") #################################################### ### Calls _step_07a_process_validated_rsync_read ### #################################################### await _step_07a_process_validated_rsync_read(process, argv, path_index, release_obj) else: _LOGGER.info(f"Processing WRITE request for {project_name}-{version_name}") ##################################################### ### Calls _step_07b_process_validated_rsync_write ### ##################################################### await _step_07b_process_validated_rsync_write(process, argv, path_index, project_name, version_name) def _step_02_command_simple_validate(argv: list[str]) -> tuple[str | None, int, bool]: """Validate the basic structure of the rsync command and detect read vs write.""" # READ: ['rsync', '--server', '--sender', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] # WRITE: ['rsync', '--server', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] if not argv: return "Empty command", -1, False if argv[0] != "rsync": return "The first argument must be rsync", -1, False if argv[1] != "--server": return "The second argument must be --server", -1, False is_read_request = False option_index = 2 # Check for --sender flag, which indicates a read request if (len(argv) > 2) and (argv[2] == "--sender"): is_read_request = True option_index = 3 if len(argv) <= option_index: return "Missing options after --sender", -1, True elif len(argv) <= 2: return "Missing options argument", -1, False # Validate the options argument strictly options = argv[option_index] if "e." not in options: return "The options argument (after --sender) must contain 'e.'", -1, True # The options after -e. are compatibility flags and can be ignored if options.split("e.", 1)[0] != "-vlogDtpr": return "The options argument (after --sender) must be '-vlogDtpre.[compatibility flags]'", -1, True #################################################### ### Calls _step_03_validate_rsync_args_structure ### #################################################### error, path_index = _step_03_validate_rsync_args_structure(argv, option_index, is_read_request) if error: return error, -1, is_read_request return None, path_index, is_read_request def _step_03_validate_rsync_args_structure( argv: list[str], option_index: int, is_read_request: bool ) -> tuple[str | None, int]: """Validate the dot argument and path argument presence and count.""" # READ: ['rsync', '--server', '--sender', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] :: 3 :: True # WRITE: ['rsync', '--server', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] :: 2 :: False dot_arg_index = option_index + 1 path_index = option_index + 2 # Write requests might have --delete has_delete = False if (not is_read_request) and (len(argv) > dot_arg_index) and (argv[dot_arg_index] == "--delete"): has_delete = True dot_arg_index += 1 path_index += 1 if (len(argv) <= dot_arg_index) or (argv[dot_arg_index] != "."): expected_pos = "fourth" if (is_read_request or (not has_delete)) else "fifth" return f"The {expected_pos} argument must be .", -1 if len(argv) <= path_index: return "Missing path argument", -1 # Check expected total number of arguments expected_len = path_index + 1 if len(argv) != expected_len: return f"Expected {expected_len} arguments, but got {len(argv)}", -1 return None, path_index async def _step_04_command_validate( process: asyncssh.SSHServerProcess, argv: list[str], path_index: int, is_read_request: bool ) -> tuple[str, str, models.Release | None] | None: """Validate the path and user permissions for read or write.""" ############################################ ### Calls _step_05_command_path_validate ### ############################################ result = _step_05_command_path_validate(argv[path_index]) if isinstance(result, str): return _fail(process, result, None) path_project, path_version = result ssh_uid = process.get_extra_info("username") async with db.session() as data: project = await data.project(name=path_project, _committee=True).get() if project is None: # Projects are public, so existence information is public return _fail(process, f"Project '{path_project}' does not exist", None) release = await data.release(project_name=project.name, version=path_version).get() if is_read_request: ################################################# ### Calls _step_06a_validate_read_permissions ### ################################################# validated_release, success = await _step_06a_validate_read_permissions( process, ssh_uid, project, release, path_project, path_version ) if success is None: return None return path_project, path_version, validated_release else: ################################################## ### Calls _step_06b_validate_write_permissions ### ################################################## success = await _step_06b_validate_write_permissions(process, ssh_uid, project, release) if success is None: return None # Return None for the release object for write requests return path_project, path_version, None def _step_05_command_path_validate(path: str) -> tuple[str, str] | str: """Validate the path argument for rsync commands.""" # READ: rsync --server --sender -vlogDtpre.iLsfxCIvu . /proj/v1/ # Validating path: /proj/v1/ # WRITE: rsync --server -vlogDtpre.iLsfxCIvu . /proj/v1/ # Validating path: /proj/v1/ if not path.startswith("/"): return "The path argument should be an absolute path" if not path.endswith("/"): # Technically we could ignore this, because we rewrite the path anyway for writes # But we should enforce good rsync usage practices return "The path argument should be a directory path, ending with a /" if "//" in path: return "The path argument should not contain //" if path.count("/") != 3: return "The path argument should be a /PROJECT/VERSION/ directory path" path_project, path_version = path.strip("/").split("/", 1) alphanum = set(string.ascii_letters + string.digits + "-") if not all(c in alphanum for c in path_project): return "The project name should contain only alphanumeric characters or hyphens" # From a survey of version numbers we find that only . and - are used # We also allow + which is in common use version_punctuation = set(".-+") if path_version[0] not in alphanum: # Must certainly not allow the directory to be called "." or ".." # And we also want to avoid patterns like ".htaccess" return "The version should start with an alphanumeric character" if path_version[-1] not in alphanum: return "The version should end with an alphanumeric character" if not all(c in (alphanum | version_punctuation) for c in path_version): return "The version should contain only alphanumeric characters, dots, dashes, or pluses" return path_project, path_version async def _step_06a_validate_read_permissions( process: asyncssh.SSHServerProcess, ssh_uid: str, project: models.Project, release: models.Release | None, path_project: str, path_version: str, ) -> tuple[models.Release | None, bool]: """Validate permissions for a read request.""" if release is None: _fail(process, f"Release '{path_project}-{path_version}' does not exist", None) return None, False allowed_read_phases = { models.ReleasePhase.RELEASE_CANDIDATE_DRAFT, models.ReleasePhase.RELEASE_CANDIDATE, models.ReleasePhase.RELEASE_PREVIEW, } if release.phase not in allowed_read_phases: _fail(process, f"Release '{release.name}' is not in a readable phase ({release.phase.value})", None) return None, False if not user.is_committer(project.committee, ssh_uid): _fail( process, f"You must be a committer or committee member for project '{project.name}' to read this release", None, ) return None, False return release, True async def _step_06b_validate_write_permissions( process: asyncssh.SSHServerProcess, ssh_uid: str, project: models.Project, release: models.Release | None, ) -> bool: """Validate permissions for a write request.""" if release is None: # Creating a new release requires committee membership if not user.is_committee_member(project.committee, ssh_uid): return _fail( process, f"You must be a member of project '{project.name}' committee to create a release", False, ) else: # Uploading to existing release, requires DRAFT and participant status if release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT: return _fail( process, f"Cannot upload: Release '{release.name}' is no longer in draft phase ({release.phase.value})", False, ) if not user.is_committer(project.committee, ssh_uid): return _fail( process, f"You must be a committer or committee member for project '{project.name}' " "to upload to this draft release", False, ) return True async def _step_07a_process_validated_rsync_read( process: asyncssh.SSHServerProcess, argv: list[str], path_index: int, release: models.Release, ) -> None: """Handle a validated rsync read request.""" exit_status = 1 try: # Determine the source directory based on the release phase and revision source_dir = util.release_directory(release) _LOGGER.info( f"Identified source directory for read: {source_dir} for release " f"{release.name} (phase {release.phase.value})" ) # Check whether the source directory actually exists before proceeding if not await aiofiles.os.path.isdir(source_dir): return _fail(process, f"Source directory '{source_dir}' not found for release {release.name}", None) # Update the rsync command path to the determined source directory argv[path_index] = str(source_dir) if not argv[path_index].endswith("/"): argv[path_index] += "/" ################################################### ### Calls _step_08_execute_rsync_sender_command ### ################################################### exit_status = await _step_08_execute_rsync(process, argv) if exit_status != 0: _LOGGER.error( f"rsync --sender failed with exit status {exit_status} for release {release.name}. " f"Command: {process.command} (run as {' '.join(argv)})" ) process.exit(exit_status) except Exception as e: _LOGGER.exception(f"Error during rsync read processing for {release.name}") _fail(process, f"Internal error processing read request: {e}", None) process.exit(1) async def _step_07b_process_validated_rsync_write( process: asyncssh.SSHServerProcess, argv: list[str], path_index: int, project_name: str, version_name: str, ) -> None: """Handle a validated rsync write request.""" asf_uid = process.get_extra_info("username") exit_status = 1 try: # Ensure the release object exists or is created # This must happen before creating the revision directory ####################################################### ### Calls _step_07c_ensure_release_object_for_write ### ####################################################### if not await _step_07c_ensure_release_object_for_write(process, project_name, version_name): # The _fail function was already called in _07b2_ensure_release_object_for_write return # Create the draft revision directory structure async with revision.create_and_manage(project_name, version_name, asf_uid) as ( new_revision_dir, new_draft_revision, ): _LOGGER.info(f"Created draft revision directory: {new_revision_dir} ({new_draft_revision})") # Update the rsync command path to the new revision directory argv[path_index] = str(new_revision_dir) ################################################### ### Calls _step_08_execute_rsync_upload_command ### ################################################### exit_status = await _step_08_execute_rsync(process, argv) if exit_status != 0: _LOGGER.error( f"rsync upload failed with exit status {exit_status} for revision {new_draft_revision}. " f"Command: {process.command} (run as {' '.join(argv)})" ) _LOGGER.info(f"rsync upload successful for revision {new_draft_revision}") # Close the connection unconditionally # If we use "if not process.is_closing():" then it fails process.exit(exit_status) except Exception as e: _LOGGER.exception(f"Error during draft revision processing for {project_name}-{version_name}") _fail(process, f"Internal error processing upload revision: {e}", None) process.exit(1) async def _step_07c_ensure_release_object_for_write( process: asyncssh.SSHServerProcess, project_name: str, version_name: str ) -> bool: """Ensure the release object exists or create it for a write operation.""" try: async with db.session() as data: async with data.begin(): release = await data.release( name=models.release_name(project_name, version_name), _committee=True ).get() if release is None: project = await data.project(name=project_name, _committee=True).demand( RuntimeError("Project not found after validation") ) if version_name_error := util.version_name_error(version_name): # This should ideally be caught by path validation, but double check raise RuntimeError(f'Invalid version name "{version_name}": {version_name_error}') # Create a new release object _LOGGER.info(f"Creating new release object for {project_name}-{version_name}") release = models.Release( project_name=project.name, project=project, version=version_name, stage=models.ReleaseStage.RELEASE_CANDIDATE, phase=models.ReleasePhase.RELEASE_CANDIDATE_DRAFT, created=datetime.datetime.now(datetime.UTC), ) data.add(release) elif release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT: return _fail( process, f"Release '{release.name}' is no longer in draft phase ({release.phase.value}) " "- cannot create new revision", False, ) return True except Exception as e: _LOGGER.exception(f"Error ensuring release object for write: {project_name}-{version_name}") return _fail(process, f"Internal error ensuring release object: {e}", False) async def _step_08_execute_rsync(process: asyncssh.SSHServerProcess, argv: list[str]) -> int: """Execute the modified rsync command.""" _LOGGER.info(f"Executing modified rsync command: {' '.join(argv)}") proc = await asyncio.create_subprocess_shell( " ".join(argv), stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) # Redirect the client's streams to the rsync process await process.redirect(stdin=proc.stdin, stdout=proc.stdout, stderr=proc.stderr) # Wait for rsync to finish and get its exit status exit_status = await proc.wait() _LOGGER.info(f"Rsync finished with exit status {exit_status}") return exit_status