# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""SSH server module for ATR."""
import asyncio
import asyncio.subprocess
import datetime
import logging
import os
import string
from typing import Final, TypeVar
import aiofiles
import aiofiles.os
import asyncssh
import atr.config as config
import atr.db as db
import atr.db.models as models
import atr.revision as revision
import atr.user as user
import atr.util as util
_LOGGER: Final = logging.getLogger(__name__)
_CONFIG: Final = config.get()
T = TypeVar("T")
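# Overview: a client runs rsync over SSH against this server, for example
# (illustrative invocation only; exact client flags and the SSH port may vary):
#
#     rsync -av ./artifacts/ user@host:/PROJECT/VERSION/
#
# asyncssh authenticates the user via SSHServer below, then passes the
# server-side command ("rsync --server ...") to _step_01_handle_client, which
# validates it, rewrites the path to a server-side directory, and re-executes
# rsync with the client's streams attached.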
class SSHServer(asyncssh.SSHServer):
"""Simple SSH server that handles connections."""
def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
"""Called when a connection is established."""
# Store connection for use in begin_auth
self._conn = conn
peer_addr = conn.get_extra_info("peername")[0]
_LOGGER.info(f"SSH connection received from {peer_addr}")
def connection_lost(self, exc: Exception | None) -> None:
"""Called when a connection is lost or closed."""
if exc:
_LOGGER.error(f"SSH connection error: {exc}")
else:
_LOGGER.info("SSH connection closed")
async def begin_auth(self, username: str) -> bool:
"""Begin authentication for the specified user."""
_LOGGER.info(f"Beginning auth for user {username}")
try:
# Load SSH keys for this user from the database
async with db.session() as data:
user_keys = await data.ssh_key(asf_uid=username).all()
if not user_keys:
_LOGGER.warning(f"No SSH keys found for user: {username}")
# Still require authentication, but it will fail
return True
# Create an authorized_keys file as a string
auth_keys_lines = []
for user_key in user_keys:
auth_keys_lines.append(user_key.key)
auth_keys_data = "\n".join(auth_keys_lines)
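                # auth_keys_data is now in OpenSSH authorized_keys format, one
                # public key per line, e.g. "ssh-ed25519 AAAA... user@example"
                # (illustrative key shown)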
_LOGGER.info(f"Loaded {len(user_keys)} SSH keys for user {username}")
# Set the authorized keys in the connection
try:
authorized_keys = asyncssh.import_authorized_keys(auth_keys_data)
self._conn.set_authorized_keys(authorized_keys)
_LOGGER.info(f"Successfully set authorized keys for {username}")
except Exception as e:
_LOGGER.error(f"Error setting authorized keys: {e}")
except Exception as e:
_LOGGER.error(f"Database error loading SSH keys: {e}")
# Always require authentication
return True
def public_key_auth_supported(self) -> bool:
"""Indicate whether public key authentication is supported."""
return True
async def server_start() -> asyncssh.SSHAcceptor:
"""Start the SSH server."""
# TODO: Where do we actually do this?
# await aiofiles.os.makedirs(_CONFIG.STATE_DIR, exist_ok=True)
# Generate temporary host key if it doesn't exist
key_path = os.path.join(_CONFIG.STATE_DIR, "ssh_host_key")
if not await aiofiles.os.path.exists(key_path):
private_key = asyncssh.generate_private_key("ssh-rsa")
private_key.write_private_key(key_path)
_LOGGER.info(f"Generated SSH host key at {key_path}")
server = await asyncssh.create_server(
SSHServer,
server_host_keys=[key_path],
process_factory=_step_01_handle_client,
host=_CONFIG.SSH_HOST,
port=_CONFIG.SSH_PORT,
encoding=None,
)
_LOGGER.info(f"SSH server started on {_CONFIG.SSH_HOST}:{_CONFIG.SSH_PORT}")
return server
async def server_stop(server: asyncssh.SSHAcceptor) -> None:
"""Stop the SSH server."""
server.close()
await server.wait_closed()
_LOGGER.info("SSH server stopped")
def _fail(process: asyncssh.SSHServerProcess, message: str, return_value: T) -> T:
_LOGGER.error(message)
# Ensure message is encoded before writing to stderr
encoded_message = f"ATR SSH error: {message}\n".encode()
try:
process.stderr.write(encoded_message)
except BrokenPipeError:
_LOGGER.warning("Failed to write error to client stderr: Broken pipe")
except Exception as e:
_LOGGER.exception(f"Error writing to client stderr: {e}")
process.exit(1)
return return_value
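# _fail is used as "return _fail(process, message, default)" throughout this
# module, so a caller can report an error, close the process, and return a
# typed default value in a single expression.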
async def _step_01_handle_client(process: asyncssh.SSHServerProcess) -> None:
"""Process client command, validating and dispatching to read or write handlers."""
asf_uid = process.get_extra_info("username")
_LOGGER.info(f"Handling command for authenticated user: {asf_uid}")
if not process.command:
return _fail(process, "No command specified", None)
_LOGGER.info(f"Command received: {process.command}")
# TODO: Use shlex.split or similar if commands can contain quoted arguments
argv = process.command.split()
##############################################
### Calls _step_02_command_simple_validate ###
##############################################
simple_validation_error, path_index, is_read_request = _step_02_command_simple_validate(argv)
if simple_validation_error:
return _fail(process, f"{simple_validation_error}\nCommand: {process.command}", None)
#######################################
### Calls _step_04_command_validate ###
#######################################
validation_results = await _step_04_command_validate(process, argv, path_index, is_read_request)
if not validation_results:
return
# Unpack results
# The release object is only present for read requests
project_name, version_name, release_obj = validation_results
if is_read_request:
if release_obj is None:
# This should not happen if the validation logic is correct
return _fail(process, "Internal error: Release object missing for read request after validation", None)
_LOGGER.info(f"Processing READ request for {project_name}-{version_name}")
####################################################
### Calls _step_07a_process_validated_rsync_read ###
####################################################
await _step_07a_process_validated_rsync_read(process, argv, path_index, release_obj)
else:
_LOGGER.info(f"Processing WRITE request for {project_name}-{version_name}")
#####################################################
### Calls _step_07b_process_validated_rsync_write ###
#####################################################
await _step_07b_process_validated_rsync_write(process, argv, path_index, project_name, version_name)
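# Step pipeline summary (as implemented below):
#   _step_01_handle_client
#     -> _step_02_command_simple_validate / _step_03_validate_rsync_args_structure
#     -> _step_04_command_validate -> _step_05_command_path_validate
#        -> _step_06a_validate_read_permissions or _step_06b_validate_write_permissions
#     -> _step_07a_process_validated_rsync_read (read)
#        or _step_07b_process_validated_rsync_write / _step_07c_ensure_release_object_for_write (write)
#     -> _step_08_execute_rsync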
def _step_02_command_simple_validate(argv: list[str]) -> tuple[str | None, int, bool]:
"""Validate the basic structure of the rsync command and detect read vs write."""
# READ: ['rsync', '--server', '--sender', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/']
# WRITE: ['rsync', '--server', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/']
if not argv:
return "Empty command", -1, False
if argv[0] != "rsync":
return "The first argument must be rsync", -1, False
if argv[1] != "--server":
return "The second argument must be --server", -1, False
is_read_request = False
option_index = 2
# Check for --sender flag, which indicates a read request
if (len(argv) > 2) and (argv[2] == "--sender"):
is_read_request = True
option_index = 3
if len(argv) <= option_index:
return "Missing options after --sender", -1, True
elif len(argv) <= 2:
return "Missing options argument", -1, False
# Validate the options argument strictly
options = argv[option_index]
if "e." not in options:
return "The options argument (after --sender) must contain 'e.'", -1, True
# The options after -e. are compatibility flags and can be ignored
if options.split("e.", 1)[0] != "-vlogDtpr":
return "The options argument (after --sender) must be '-vlogDtpre.[compatibility flags]'", -1, True
####################################################
### Calls _step_03_validate_rsync_args_structure ###
####################################################
error, path_index = _step_03_validate_rsync_args_structure(argv, option_index, is_read_request)
if error:
return error, -1, is_read_request
return None, path_index, is_read_request
def _step_03_validate_rsync_args_structure(
argv: list[str], option_index: int, is_read_request: bool
) -> tuple[str | None, int]:
"""Validate the dot argument and path argument presence and count."""
# READ: ['rsync', '--server', '--sender', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] :: 3 :: True
# WRITE: ['rsync', '--server', '-vlogDtpre.iLsfxCIvu', '.', '/proj/v1/'] :: 2 :: False
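    # WRITE with --delete (handled by the branch below):
    #        ['rsync', '--server', '-vlogDtpre.iLsfxCIvu', '--delete', '.', '/proj/v1/'] :: 2 :: False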
dot_arg_index = option_index + 1
path_index = option_index + 2
# Write requests might have --delete
has_delete = False
if (not is_read_request) and (len(argv) > dot_arg_index) and (argv[dot_arg_index] == "--delete"):
has_delete = True
dot_arg_index += 1
path_index += 1
if (len(argv) <= dot_arg_index) or (argv[dot_arg_index] != "."):
expected_pos = "fourth" if (is_read_request or (not has_delete)) else "fifth"
return f"The {expected_pos} argument must be .", -1
if len(argv) <= path_index:
return "Missing path argument", -1
# Check expected total number of arguments
expected_len = path_index + 1
if len(argv) != expected_len:
return f"Expected {expected_len} arguments, but got {len(argv)}", -1
return None, path_index
async def _step_04_command_validate(
process: asyncssh.SSHServerProcess, argv: list[str], path_index: int, is_read_request: bool
) -> tuple[str, str, models.Release | None] | None:
"""Validate the path and user permissions for read or write."""
############################################
### Calls _step_05_command_path_validate ###
############################################
result = _step_05_command_path_validate(argv[path_index])
if isinstance(result, str):
return _fail(process, result, None)
path_project, path_version = result
ssh_uid = process.get_extra_info("username")
async with db.session() as data:
project = await data.project(name=path_project, _committee=True).get()
if project is None:
# Projects are public, so existence information is public
return _fail(process, f"Project '{path_project}' does not exist", None)
release = await data.release(project_name=project.name, version=path_version).get()
if is_read_request:
#################################################
### Calls _step_06a_validate_read_permissions ###
#################################################
validated_release, success = await _step_06a_validate_read_permissions(
process, ssh_uid, project, release, path_project, path_version
)
            if not success:
                return None
return path_project, path_version, validated_release
else:
##################################################
### Calls _step_06b_validate_write_permissions ###
##################################################
success = await _step_06b_validate_write_permissions(process, ssh_uid, project, release)
            if not success:
                return None
# Return None for the release object for write requests
return path_project, path_version, None
def _step_05_command_path_validate(path: str) -> tuple[str, str] | str:
"""Validate the path argument for rsync commands."""
# READ: rsync --server --sender -vlogDtpre.iLsfxCIvu . /proj/v1/
# Validating path: /proj/v1/
# WRITE: rsync --server -vlogDtpre.iLsfxCIvu . /proj/v1/
# Validating path: /proj/v1/
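    # Illustrative outcomes: "/tooling-test/0.2+rc1/" -> ("tooling-test", "0.2+rc1"),
    # whereas "/proj/.hidden/" and "/proj/v1" (no trailing slash) are rejected below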
if not path.startswith("/"):
return "The path argument should be an absolute path"
if not path.endswith("/"):
# Technically we could ignore this, because we rewrite the path anyway for writes
# But we should enforce good rsync usage practices
return "The path argument should be a directory path, ending with a /"
if "//" in path:
return "The path argument should not contain //"
if path.count("/") != 3:
return "The path argument should be a /PROJECT/VERSION/ directory path"
path_project, path_version = path.strip("/").split("/", 1)
    alphanum = set(string.ascii_letters + string.digits)
    if not all(c in (alphanum | {"-"}) for c in path_project):
        return "The project name should contain only alphanumeric characters or hyphens"
# From a survey of version numbers we find that only . and - are used
# We also allow + which is in common use
version_punctuation = set(".-+")
if path_version[0] not in alphanum:
# Must certainly not allow the directory to be called "." or ".."
# And we also want to avoid patterns like ".htaccess"
return "The version should start with an alphanumeric character"
if path_version[-1] not in alphanum:
return "The version should end with an alphanumeric character"
if not all(c in (alphanum | version_punctuation) for c in path_version):
return "The version should contain only alphanumeric characters, dots, dashes, or pluses"
return path_project, path_version
async def _step_06a_validate_read_permissions(
process: asyncssh.SSHServerProcess,
ssh_uid: str,
project: models.Project,
release: models.Release | None,
path_project: str,
path_version: str,
) -> tuple[models.Release | None, bool]:
"""Validate permissions for a read request."""
if release is None:
_fail(process, f"Release '{path_project}-{path_version}' does not exist", None)
return None, False
allowed_read_phases = {
models.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
models.ReleasePhase.RELEASE_CANDIDATE,
models.ReleasePhase.RELEASE_PREVIEW,
}
if release.phase not in allowed_read_phases:
_fail(process, f"Release '{release.name}' is not in a readable phase ({release.phase.value})", None)
return None, False
if not user.is_committer(project.committee, ssh_uid):
_fail(
process,
f"You must be a committer or committee member for project '{project.name}' to read this release",
None,
)
return None, False
return release, True
async def _step_06b_validate_write_permissions(
process: asyncssh.SSHServerProcess,
ssh_uid: str,
project: models.Project,
release: models.Release | None,
) -> bool:
"""Validate permissions for a write request."""
if release is None:
# Creating a new release requires committee membership
if not user.is_committee_member(project.committee, ssh_uid):
return _fail(
process,
f"You must be a member of project '{project.name}' committee to create a release",
False,
)
else:
        # Uploading to an existing release requires the draft phase and committer status
if release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
return _fail(
process,
f"Cannot upload: Release '{release.name}' is no longer in draft phase ({release.phase.value})",
False,
)
if not user.is_committer(project.committee, ssh_uid):
return _fail(
process,
f"You must be a committer or committee member for project '{project.name}' "
"to upload to this draft release",
False,
)
return True
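# Permission summary (as implemented in _step_06a and _step_06b above):
#   read:  committer or committee member; phase must be RELEASE_CANDIDATE_DRAFT,
#          RELEASE_CANDIDATE, or RELEASE_PREVIEW
#   write: creating a new release requires committee membership; an existing
#          release must be in RELEASE_CANDIDATE_DRAFT and the user must be a
#          committer or committee member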
async def _step_07a_process_validated_rsync_read(
process: asyncssh.SSHServerProcess,
argv: list[str],
path_index: int,
release: models.Release,
) -> None:
"""Handle a validated rsync read request."""
exit_status = 1
try:
# Determine the source directory based on the release phase and revision
source_dir = util.release_directory(release)
_LOGGER.info(
f"Identified source directory for read: {source_dir} for release "
f"{release.name} (phase {release.phase.value})"
)
# Check whether the source directory actually exists before proceeding
if not await aiofiles.os.path.isdir(source_dir):
return _fail(process, f"Source directory '{source_dir}' not found for release {release.name}", None)
# Update the rsync command path to the determined source directory
argv[path_index] = str(source_dir)
if not argv[path_index].endswith("/"):
argv[path_index] += "/"
        ####################################
        ### Calls _step_08_execute_rsync ###
        ####################################
exit_status = await _step_08_execute_rsync(process, argv)
if exit_status != 0:
_LOGGER.error(
f"rsync --sender failed with exit status {exit_status} for release {release.name}. "
f"Command: {process.command} (run as {' '.join(argv)})"
)
process.exit(exit_status)
except Exception as e:
_LOGGER.exception(f"Error during rsync read processing for {release.name}")
_fail(process, f"Internal error processing read request: {e}", None)
process.exit(1)
async def _step_07b_process_validated_rsync_write(
process: asyncssh.SSHServerProcess,
argv: list[str],
path_index: int,
project_name: str,
version_name: str,
) -> None:
"""Handle a validated rsync write request."""
asf_uid = process.get_extra_info("username")
exit_status = 1
try:
# Ensure the release object exists or is created
# This must happen before creating the revision directory
#######################################################
### Calls _step_07c_ensure_release_object_for_write ###
#######################################################
if not await _step_07c_ensure_release_object_for_write(process, project_name, version_name):
            # The _fail function was already called in _step_07c_ensure_release_object_for_write
return
# Create the draft revision directory structure
async with revision.create_and_manage(project_name, version_name, asf_uid) as (
new_revision_dir,
new_draft_revision,
):
_LOGGER.info(f"Created draft revision directory: {new_revision_dir} ({new_draft_revision})")
# Update the rsync command path to the new revision directory
argv[path_index] = str(new_revision_dir)
            ####################################
            ### Calls _step_08_execute_rsync ###
            ####################################
exit_status = await _step_08_execute_rsync(process, argv)
if exit_status != 0:
_LOGGER.error(
f"rsync upload failed with exit status {exit_status} for revision {new_draft_revision}. "
f"Command: {process.command} (run as {' '.join(argv)})"
)
_LOGGER.info(f"rsync upload successful for revision {new_draft_revision}")
# Close the connection unconditionally
# If we use "if not process.is_closing():" then it fails
process.exit(exit_status)
except Exception as e:
_LOGGER.exception(f"Error during draft revision processing for {project_name}-{version_name}")
_fail(process, f"Internal error processing upload revision: {e}", None)
process.exit(1)
async def _step_07c_ensure_release_object_for_write(
process: asyncssh.SSHServerProcess, project_name: str, version_name: str
) -> bool:
"""Ensure the release object exists or create it for a write operation."""
try:
async with db.session() as data:
async with data.begin():
release = await data.release(
name=models.release_name(project_name, version_name), _committee=True
).get()
if release is None:
project = await data.project(name=project_name, _committee=True).demand(
RuntimeError("Project not found after validation")
)
if version_name_error := util.version_name_error(version_name):
# This should ideally be caught by path validation, but double check
raise RuntimeError(f'Invalid version name "{version_name}": {version_name_error}')
# Create a new release object
_LOGGER.info(f"Creating new release object for {project_name}-{version_name}")
release = models.Release(
project_name=project.name,
project=project,
version=version_name,
stage=models.ReleaseStage.RELEASE_CANDIDATE,
phase=models.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
created=datetime.datetime.now(datetime.UTC),
)
data.add(release)
elif release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
return _fail(
process,
f"Release '{release.name}' is no longer in draft phase ({release.phase.value}) "
"- cannot create new revision",
False,
)
return True
except Exception as e:
_LOGGER.exception(f"Error ensuring release object for write: {project_name}-{version_name}")
return _fail(process, f"Internal error ensuring release object: {e}", False)
async def _step_08_execute_rsync(process: asyncssh.SSHServerProcess, argv: list[str]) -> int:
"""Execute the modified rsync command."""
_LOGGER.info(f"Executing modified rsync command: {' '.join(argv)}")
    # Run rsync directly rather than through a shell so that client-supplied
    # arguments are never interpreted by a shell
    proc = await asyncio.create_subprocess_exec(
        *argv,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
# Redirect the client's streams to the rsync process
await process.redirect(stdin=proc.stdin, stdout=proc.stdout, stderr=proc.stderr)
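    # process.redirect() relays data between the SSH channel and the rsync
    # subprocess pipes in both directions for the lifetime of the transfer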
# Wait for rsync to finish and get its exit status
exit_status = await proc.wait()
_LOGGER.info(f"Rsync finished with exit status {exit_status}")
return exit_status