# client/securedrop_client/api_jobs/downloads.py (214 lines of code) (raw):
import binascii
import hashlib
import logging
import math
import os
from tempfile import NamedTemporaryFile
from typing import Any
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
from securedrop_client.api_jobs.base import SingleObjectApiJob
from securedrop_client.crypto import CryptoError, GpgHelper
from securedrop_client.db import DownloadError, DownloadErrorCodes, File, Message, Reply
from securedrop_client.sdk import API, BaseError
from securedrop_client.sdk import Reply as SdkReply
from securedrop_client.sdk import Submission as SdkSubmission
from securedrop_client.storage import (
mark_as_decrypted,
mark_as_downloaded,
set_message_or_reply_content,
)
from securedrop_client.utils import safe_move
logger = logging.getLogger(__name__)
class DownloadException(Exception):
    """
    Base class for errors raised by the download jobs in this module.

    Besides the usual exception message, an instance records which model class the job was
    working on (``object_type``) and the UUID of the affected item (``uuid``).
    """

    def __init__(
        self, message: str, object_type: type[Reply] | type[Message] | type[File], uuid: str
    ):
        super().__init__(message)
        # Context describing the item the job was processing when the error occurred.
        self.uuid = uuid
        self.object_type = object_type
class DownloadChecksumMismatchException(DownloadException):
    """
    Signals that the checksum computed for a downloaded file differs from the one supplied by
    the SecureDrop server.
    """
class DownloadDecryptionException(DownloadException):
    """
    Signals that a download could not be decrypted.

    Note: DownloadJob._download also raises this type when the download step itself fails, so
    handlers should not assume the failure happened during decryption.
    """
class DownloadJob(SingleObjectApiJob):
    """
    Base job that downloads and then decrypts one item: a message, a reply, or a file submission.

    Subclasses provide the database lookup (get_db_object), the download API call
    (call_download_api) and the decryption step (call_decrypt).
    """

    # Number of bytes read per iteration when hashing a downloaded file.
    CHUNK_SIZE = 4096

    def __init__(self, data_dir: str, uuid: str) -> None:
        super().__init__(uuid)
        self.data_dir = data_dir

    def _get_realistic_timeout(self, size_in_bytes: int) -> int:
        """
        Compute a download timeout, in seconds, that grows with the size of the download.

        Background:

        * The server reports the file size in bytes (it computes it using os.stat.ST_SIZE).
        * Reasonable estimates for fetching a file over Tor, according to Tor metrics for a
          three month period in 2022-2023:

            50 KiB (51200 bytes)  =  6 seconds (8,533 bytes/second)
            1 MiB (1049000 bytes) = 15 seconds (~ 69,905 bytes/second)

          See:
          https://metrics.torproject.org/torperf.html?start=2022-12-06&end=2023-03-06&server=onion
        * The returned timeout is deliberately larger than the expected download time, so the
          assumed rate (50000 bytes/second) is slower than the measured ~69905 bytes/second and
          the result is further scaled by an adjustment factor.
        * 25 seconds are always added on top, so that is the smallest timeout returned for a
          non-negative size.
        """
        TIMEOUT_BASE = 25
        TIMEOUT_ADJUSTMENT_FACTOR = 1.5
        TIMEOUT_BYTES_PER_SECOND = 50_000.0

        expected_seconds = size_in_bytes / TIMEOUT_BYTES_PER_SECOND
        return TIMEOUT_BASE + math.ceil(expected_seconds * TIMEOUT_ADJUSTMENT_FACTOR)

    def call_download_api(self, api: API, db_object: File | Message | Reply) -> tuple[str, str]:
        """
        Perform the actual API request that downloads the item. Implemented by subclasses.

        Implementations MUST return the (etag, filepath) tuple received from the server and
        MUST raise an exception if and only if the download fails.
        """
        raise NotImplementedError

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Decrypt the file at `filepath` and store the plaintext. Implemented by subclasses.

        Implementations return the original filename and MUST raise an exception if and only
        if the decryption fails.
        """
        raise NotImplementedError

    def get_db_object(self, session: Session) -> File | Message | Reply:
        """
        Look up the database object this job operates on. Implemented by subclasses.

        May raise DownloadException when the object cannot be found.
        """
        raise NotImplementedError

    def call_api(self, api_client: API, session: Session) -> Any:
        """
        Override ApiJob.

        Make sure the item behind this job ends up downloaded and decrypted, skipping whichever
        steps the database already marks as done. Returns the UUID of the database object.
        """
        item = self.get_db_object(session)

        if item.is_decrypted:
            logger.debug(f"item with uuid {self.uuid} already decrypted, returning")
            return item.uuid

        if item.is_downloaded:
            # The encrypted file is already in the data directory; only decryption is pending.
            logger.debug(f"item with uuid {self.uuid} already downloaded, now decrypting")
            encrypted_path = item.location(self.data_dir)
        else:
            encrypted_path = self._download(api_client, item, session)

        self._decrypt(encrypted_path, item, session)
        return item.uuid

    def _download(self, api: API, db_object: File | Message | Reply, session: Session) -> str:
        """
        Fetch the encrypted file, verify its checksum, move it into the data directory and mark
        the database object as downloaded. Returns the final location of the encrypted file.

        Note: On Qubes OS, files are downloaded to /home/user/Downloads

        Raises DownloadChecksumMismatchException when the integrity check fails (the checksum
        error is recorded on the object and committed first). A ValueError, FileNotFoundError,
        RuntimeError or BaseError raised along the way is re-raised as
        DownloadDecryptionException, even though the failure happened while downloading.
        """
        try:
            etag, download_path = self.call_download_api(api, db_object)

            checksum_ok = self._check_file_integrity(etag, download_path)
            if not checksum_ok:
                checksum_error_query = session.query(DownloadError).filter_by(
                    name=DownloadErrorCodes.CHECKSUM_ERROR.name
                )
                db_object.download_error = checksum_error_query.one()
                session.commit()
                raise DownloadChecksumMismatchException(
                    "Downloaded file had an invalid checksum.", type(db_object), db_object.uuid
                )

            destination = db_object.location(self.data_dir)
            safe_move(download_path, destination, self.data_dir)
            # Clear any error left over from an earlier attempt before flagging success.
            db_object.download_error = None
            mark_as_downloaded(type(db_object), db_object.uuid, session)
            logger.info(f"File downloaded to {destination}")
            return destination
        except (ValueError, FileNotFoundError, RuntimeError, BaseError) as e:
            # The exception details go to the debug log only; the error log stays generic.
            logger.error("Download failed")
            logger.debug(f"Download failed: {e}")
            raise DownloadDecryptionException(
                f"Failed to download {db_object.uuid}", type(db_object), db_object.uuid
            ) from e

    def _decrypt(self, filepath: str, db_object: File | Message | Reply, session: Session) -> None:
        """
        Decrypt the file at `filepath` and mark the database object as decrypted.

        On CryptoError the object is marked as not decrypted, the decryption error is recorded
        on it and committed, and DownloadDecryptionException is raised.
        """
        try:
            decrypted_name = self.call_decrypt(filepath, session)
            db_object.download_error = None
            mark_as_decrypted(
                type(db_object), db_object.uuid, session, original_filename=decrypted_name
            )
            logger.info(f"File decrypted to {os.path.dirname(filepath)}")
        except CryptoError as e:
            logger.error("Decryption failed")
            logger.debug(f"Decryption failed: {e}")
            mark_as_decrypted(type(db_object), db_object.uuid, session, is_decrypted=False)
            decryption_error_query = session.query(DownloadError).filter_by(
                name=DownloadErrorCodes.DECRYPTION_ERROR.name
            )
            db_object.download_error = decryption_error_query.one()
            session.commit()
            raise DownloadDecryptionException(
                f"Failed to decrypt file: {os.path.basename(filepath)}",
                type(db_object),
                db_object.uuid,
            ) from e

    @classmethod
    def _check_file_integrity(cls, etag: str, file_path: str) -> bool:
        """
        Report whether the file's checksum is acceptable: True when it matches the ETag or when
        it cannot be checked (empty ETag, or an algorithm other than sha256), False on mismatch.

        The ETag is expected to look like "<algorithm>:<hex digest>". A non-empty ETag that does
        not contain exactly one ":" makes the unpacking below raise ValueError.
        """
        if not etag:
            logger.debug(f"No ETag. Skipping integrity check for file at {file_path}")
            return True

        alg, checksum = etag.split(":")
        if alg != "sha256":
            logger.debug(
                f"Unknown hash algorithm ({alg}). Skipping integrity check for file at {file_path}"
            )
            return True

        digest = hashlib.sha256()
        with open(file_path, "rb") as downloaded_file:
            # Hash in fixed-size chunks so the whole file is never held in memory at once.
            while chunk := downloaded_file.read(cls.CHUNK_SIZE):
                digest.update(chunk)

        return digest.hexdigest() == checksum
class ReplyDownloadJob(DownloadJob):
    """
    Download and decrypt a reply from a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> Reply:
        """
        Override DownloadJob.

        Return the Reply whose uuid matches this job.

        Raises DownloadException (chained to the underlying NoResultFound) when no such reply
        exists in the database.
        """
        try:
            return session.query(Reply).filter_by(uuid=self.uuid).one()
        except NoResultFound as e:
            # Chain explicitly so the traceback shows the lookup failure as the direct cause.
            raise DownloadException("Reply not found in database", Reply, self.uuid) from e

    def call_download_api(self, api: API, db_object: Reply) -> tuple[str, str]:
        """
        Override DownloadJob.

        Build an SDK reply from the database reply (uuid, filename and source uuid) and return
        the result of api.download_reply for it.
        """
        sdk_object = SdkReply(uuid=db_object.uuid, filename=db_object.filename)
        sdk_object.source_uuid = db_object.source.uuid

        # TODO: Once https://github.com/freedomofpress/securedrop-sdk/issues/108 is implemented, we
        # will want to pass the default request timeout to download_reply instead of setting it on
        # the api object directly.
        api.default_request_timeout = 20
        return api.download_reply(sdk_object)

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the file located at the given filepath and store its plaintext content in the local
        database.

        The plaintext is written to a temporary file that is removed automatically when the
        `with` block exits. Afterwards, whether or not decryption succeeded, a best-effort
        attempt is made to remove the directory holding the encrypted file; an OSError from
        that attempt is only logged.

        The return value is an empty string; replies have no original filename.
        """
        with NamedTemporaryFile("w+") as plaintext_file:
            try:
                self.gpg.decrypt_submission_or_reply(filepath, plaintext_file.name, is_doc=False)
                set_message_or_reply_content(
                    model_type=Reply, uuid=self.uuid, session=session, content=plaintext_file.read()
                )
            finally:
                try:
                    os.rmdir(os.path.dirname(filepath))
                except OSError:
                    msg = f"Could not delete decryption directory: {os.path.dirname(filepath)}"
                    logger.debug(msg)

        return ""
class MessageDownloadJob(DownloadJob):
    """
    Download and decrypt a message from a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        # The uuid is handed to the parent constructor, exactly as the reply and file jobs do,
        # so it is not assigned a second time here.
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> Message:
        """
        Override DownloadJob.

        Return the Message whose uuid matches this job.

        Raises DownloadException (chained to the underlying NoResultFound) when no such message
        exists in the database.
        """
        try:
            return session.query(Message).filter_by(uuid=self.uuid).one()
        except NoResultFound as e:
            # Chain explicitly so the traceback shows the lookup failure as the direct cause.
            raise DownloadException("Message not found in database", Message, self.uuid) from e

    def call_download_api(self, api: API, db_object: Message) -> tuple[str, str]:
        """
        Override DownloadJob.

        Build an SDK submission from the database message (uuid, source uuid and filename) and
        return the result of api.download_submission, using a timeout scaled to the message size.
        """
        sdk_object = SdkSubmission(uuid=db_object.uuid)
        sdk_object.source_uuid = db_object.source.uuid
        sdk_object.filename = db_object.filename
        return api.download_submission(
            sdk_object, timeout=self._get_realistic_timeout(db_object.size)
        )

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the file located at the given filepath and store its plaintext content in the local
        database.

        The plaintext is written to a temporary file that is removed automatically when the
        `with` block exits. Afterwards, whether or not decryption succeeded, a best-effort
        attempt is made to remove the directory holding the encrypted file; an OSError from
        that attempt is only logged.

        The return value is an empty string; messages have no original filename.
        """
        with NamedTemporaryFile("w+") as plaintext_file:
            try:
                self.gpg.decrypt_submission_or_reply(filepath, plaintext_file.name, is_doc=False)
                set_message_or_reply_content(
                    model_type=Message,
                    uuid=self.uuid,
                    session=session,
                    content=plaintext_file.read(),
                )
            finally:
                try:
                    os.rmdir(os.path.dirname(filepath))
                except OSError:
                    msg = f"Could not delete decryption directory: {os.path.dirname(filepath)}"
                    logger.debug(msg)

        return ""
class FileDownloadJob(DownloadJob):
    """
    Download and decrypt a file from a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> File:
        """
        Override DownloadJob.

        Return the File whose uuid matches this job.

        Raises DownloadException (chained to the underlying NoResultFound) when no such file
        exists in the database.
        """
        try:
            return session.query(File).filter_by(uuid=self.uuid).one()
        except NoResultFound as e:
            # Chain explicitly so the traceback shows the lookup failure as the direct cause.
            raise DownloadException("File not found in database", File, self.uuid) from e

    def call_download_api(self, api: API, db_object: File) -> tuple[str, str]:
        """
        Override DownloadJob.

        Build an SDK submission from the database file (uuid, source uuid and filename) and
        return the result of api.download_submission, using a timeout scaled to the file size.
        """
        sdk_object = SdkSubmission(uuid=db_object.uuid)
        sdk_object.source_uuid = db_object.source.uuid
        sdk_object.filename = db_object.filename
        return api.download_submission(
            sdk_object, timeout=self._get_realistic_timeout(db_object.size)
        )

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the file located at the given filepath and store its plaintext content in a file on
        the filesystem, next to the encrypted file.

        The file storing the plaintext should have the same name as the downloaded file but without
        the file extensions, e.g. 1-impractical_thing-doc.gz.gpg -> 1-impractical_thing-doc

        Returns the value produced by decrypt_submission_or_reply. The `session` argument is
        accepted for interface compatibility and is not used here.
        """
        # Strip up to two trailing extensions (e.g. ".gpg", then ".gz") from the file name.
        fn_no_ext, _ = os.path.splitext(os.path.splitext(os.path.basename(filepath))[0])
        plaintext_filepath = os.path.join(os.path.dirname(filepath), fn_no_ext)
        return self.gpg.decrypt_submission_or_reply(filepath, plaintext_filepath, is_doc=True)