atr/tasks/bulk.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
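"""Bulk download task for release artifacts.

Crawl a base URL for files matching the requested file types, download the
matching artifacts, record any accompanying .asc signatures, and report
progress back to the task table so the web UI can display it.
"""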
import asyncio
import dataclasses
import html.parser
import json
import logging
import os
import urllib.parse
from typing import Any, Final
import aiofiles
import aiohttp
import sqlalchemy
import sqlalchemy.ext.asyncio
import atr.db.models as models
import atr.tasks.task as task
from atr import config
# Configure detailed logging
_LOGGER: Final = logging.getLogger(__name__)
_LOGGER.setLevel(logging.DEBUG)
# Create a file handler for tasks-bulk.log
file_handler: Final[logging.FileHandler] = logging.FileHandler("tasks-bulk.log")
file_handler.setLevel(logging.DEBUG)
# Create formatter with detailed information
formatter: Final[logging.Formatter] = logging.Formatter(
"[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler.setFormatter(formatter)
_LOGGER.addHandler(file_handler)
# Ensure parent loggers don't duplicate messages
_LOGGER.propagate = False
_LOGGER.info("Bulk download module imported")
global_db_connection: sqlalchemy.ext.asyncio.async_sessionmaker | None = None
global_task_id: int | None = None
# TODO: Use a Pydantic model instead
@dataclasses.dataclass
class Args:
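    """Validated arguments for a bulk download task."""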
release_name: str
base_url: str
file_types: list[str]
require_sigs: bool
max_depth: int
max_concurrent: int
@staticmethod
def from_dict(args: dict[str, Any]) -> "Args":
"""Parse command line arguments."""
_LOGGER.debug(f"Parsing arguments: {args}")
if len(args) != 6:
_LOGGER.error(f"Invalid number of arguments: {len(args)}, expected 6")
raise ValueError("Invalid number of arguments")
release_name = args["release_name"]
base_url = args["base_url"]
file_types = args["file_types"]
require_sigs = args["require_sigs"]
max_depth = args["max_depth"]
max_concurrent = args["max_concurrent"]
_LOGGER.debug(
f"Extracted values - release_name: {release_name}, base_url: {base_url}, "
f"file_types: {file_types}, require_sigs: {require_sigs}, "
f"max_depth: {max_depth}, max_concurrent: {max_concurrent}"
)
        if not isinstance(release_name, str):
            _LOGGER.error(f"Release name must be a string, got {type(release_name)}")
            raise ValueError("Release name must be a string")
if not isinstance(base_url, str):
_LOGGER.error(f"Base URL must be a string, got {type(base_url)}")
raise ValueError("Base URL must be a string")
if not isinstance(file_types, list):
_LOGGER.error(f"File types must be a list, got {type(file_types)}")
raise ValueError("File types must be a list")
for arg in file_types:
if not isinstance(arg, str):
_LOGGER.error(f"File types must be a list of strings, got {type(arg)}")
raise ValueError("File types must be a list of strings")
if not isinstance(require_sigs, bool):
_LOGGER.error(f"Require sigs must be a boolean, got {type(require_sigs)}")
raise ValueError("Require sigs must be a boolean")
if not isinstance(max_depth, int):
_LOGGER.error(f"Max depth must be an integer, got {type(max_depth)}")
raise ValueError("Max depth must be an integer")
if not isinstance(max_concurrent, int):
_LOGGER.error(f"Max concurrent must be an integer, got {type(max_concurrent)}")
raise ValueError("Max concurrent must be an integer")
_LOGGER.debug("All argument validations passed")
args_obj = Args(
release_name=release_name,
base_url=base_url,
file_types=file_types,
require_sigs=require_sigs,
max_depth=max_depth,
max_concurrent=max_concurrent,
)
_LOGGER.info(f"Args object created: {args_obj}")
return args_obj
class LinkExtractor(html.parser.HTMLParser):
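    """Collect the href targets of all <a> tags in an HTML document.

    Feeding '<a href="apache-example-1.0.tar.gz">...</a>', for example, leaves
    ["apache-example-1.0.tar.gz"] in self.links.
    """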
def __init__(self) -> None:
super().__init__()
self.links: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag == "a":
for attr, value in attrs:
if attr == "href" and value:
self.links.append(value)
async def artifact_download(url: str, semaphore: asyncio.Semaphore) -> bool:
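    """Download a single artifact, logging and swallowing any exception.

    Returns True on success and False on any failure.
    """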
_LOGGER.debug(f"Starting download of artifact: {url}")
try:
success = await artifact_download_core(url, semaphore)
if success:
_LOGGER.info(f"Successfully downloaded artifact: {url}")
else:
_LOGGER.warning(f"Failed to download artifact: {url}")
return success
except Exception as e:
_LOGGER.exception(f"Error downloading artifact {url}: {e}")
return False
async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) -> bool:
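    """Stream one artifact into downloads/<filename>, removing any partial file on error."""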
_LOGGER.debug(f"Starting core download process for {url}")
async with semaphore:
_LOGGER.debug(f"Acquired semaphore for {url}")
# TODO: We flatten the hierarchy to get the filename
# We should preserve the hierarchy
filename = url.split("/")[-1]
if filename.startswith("."):
raise ValueError(f"Invalid filename: {filename}")
local_path = os.path.join("downloads", filename)
# Create download directory if it doesn't exist
# TODO: Check whether local_path itself exists first
os.makedirs("downloads", exist_ok=True)
_LOGGER.debug(f"Downloading {url} to {local_path}")
try:
async with aiohttp.ClientSession() as session:
_LOGGER.debug(f"Created HTTP session for {url}")
async with session.get(url) as response:
if response.status != 200:
_LOGGER.warning(f"Failed to download {url}: HTTP {response.status}")
return False
total_size = int(response.headers.get("Content-Length", 0))
if total_size:
_LOGGER.info(f"Content-Length: {total_size} bytes for {url}")
chunk_size = 8192
downloaded = 0
_LOGGER.debug(f"Writing file to {local_path} with chunk size {chunk_size}")
async with aiofiles.open(local_path, "wb") as f:
async for chunk in response.content.iter_chunked(chunk_size):
await f.write(chunk)
downloaded += len(chunk)
# if total_size:
# progress = (downloaded / total_size) * 100
# if downloaded % (chunk_size * 128) == 0:
# _LOGGER.debug(
# f"Download progress for {filename}:"
# f" {progress:.1f}% ({downloaded}/{total_size} bytes)"
# )
_LOGGER.info(f"Download complete: {url} -> {local_path} ({downloaded} bytes)")
return True
except Exception as e:
_LOGGER.exception(f"Error during download of {url}: {e}")
# Remove partial download if an error occurred
if os.path.exists(local_path):
_LOGGER.debug(f"Removing partial download: {local_path}")
try:
os.remove(local_path)
except Exception as del_err:
_LOGGER.error(f"Error removing partial download {local_path}: {del_err}")
return False
async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore: asyncio.Semaphore) -> tuple[list[str], list[str]]:
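    """Breadth-first crawl from args.base_url, up to args.max_depth levels deep.

    Returns (signatures, artifacts): URLs ending in a configured file type are
    collected as artifacts, URLs ending in that type plus ".asc" as signatures,
    and any other URL is fetched and parsed for further links.
    """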
_LOGGER.info(f"Starting URL crawling from {args.base_url}")
await database_message(f"Crawling artifact URLs from {args.base_url}")
signatures: list[str] = []
artifacts: list[str] = []
seen: set[str] = set()
_LOGGER.debug(f"Adding base URL to queue: {args.base_url}")
await queue.put(args.base_url)
_LOGGER.debug("Starting crawl loop")
depth = 0
# Start with just the base URL
urls_at_current_depth = 1
urls_at_next_depth = 0
while (not queue.empty()) and (depth < args.max_depth):
_LOGGER.debug(f"Processing depth {depth + 1}/{args.max_depth}, queue size: {queue.qsize()}")
# Process all URLs at the current depth before moving to the next
for _ in range(urls_at_current_depth):
if queue.empty():
break
url = await queue.get()
_LOGGER.debug(f"Processing URL: {url}")
if url_excluded(seen, url, args):
continue
seen.add(url)
_LOGGER.debug(f"Checking URL for file types: {args.file_types}")
# If not a target file type, try to parse HTML links
if not check_matches(args, url, artifacts, signatures):
_LOGGER.debug(f"URL is not a target file, parsing HTML: {url}")
try:
new_urls = await download_html(url, semaphore)
_LOGGER.debug(f"Found {len(new_urls)} new URLs in {url}")
for new_url in new_urls:
if new_url not in seen:
_LOGGER.debug(f"Adding new URL to queue: {new_url}")
await queue.put(new_url)
urls_at_next_depth += 1
except Exception as e:
_LOGGER.warning(f"Error parsing HTML from {url}: {e}")
# Move to next depth
depth += 1
urls_at_current_depth = urls_at_next_depth
urls_at_next_depth = 0
# Update database with progress message
progress_msg = f"Crawled {len(seen)} URLs, found {len(artifacts)} artifacts (depth {depth}/{args.max_depth})"
await database_message(progress_msg, progress=(30 + min(50, depth * 10), 100))
_LOGGER.debug(f"Moving to depth {depth + 1}, {urls_at_current_depth} URLs to process")
_LOGGER.info(f"URL crawling complete. Found {len(artifacts)} artifacts and {len(signatures)} signatures")
return signatures, artifacts
async def artifacts_download(artifacts: list[str], semaphore: asyncio.Semaphore) -> list[str]:
"""Download artifacts with progress tracking."""
size = len(artifacts)
_LOGGER.info(f"Starting download of {size} artifacts")
downloaded = []
for i, artifact in enumerate(artifacts):
progress_percent = int((i / size) * 100) if (size > 0) else 100
progress_msg = f"Downloading {i + 1}/{size} artifacts"
_LOGGER.info(f"{progress_msg}: {artifact}")
await database_message(progress_msg, progress=(progress_percent, 100))
success = await artifact_download(artifact, semaphore)
if success:
_LOGGER.debug(f"Successfully downloaded: {artifact}")
downloaded.append(artifact)
else:
_LOGGER.warning(f"Failed to download: {artifact}")
_LOGGER.info(f"Download complete. Successfully downloaded {len(downloaded)}/{size} artifacts")
await database_message(f"Downloaded {len(downloaded)} artifacts", progress=(100, 100))
return downloaded
def check_matches(args: Args, url: str, artifacts: list[str], signatures: list[str]) -> bool:
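    """Classify url as an artifact or signature, appending it to the matching list.

    With args.file_types == [".tar.gz"], for example, "foo.tar.gz" is recorded
    as an artifact and "foo.tar.gz.asc" as a signature; unmatched URLs return False.
    """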
    for file_type in args.file_types:
        if url.endswith(file_type):
            _LOGGER.info(f"Found artifact: {url}")
            artifacts.append(url)
            return True
        elif url.endswith(file_type + ".asc"):
_LOGGER.info(f"Found signature: {url}")
signatures.append(url)
return True
return False
async def database_message(msg: str, progress: tuple[int, int] | None = None) -> None:
"""Update database with message and progress."""
_LOGGER.debug(f"Updating database with message: '{msg}', progress: {progress}")
try:
task_id = await database_task_id_get()
if task_id:
_LOGGER.debug(f"Found task_id: {task_id}, updating with message")
await database_task_update(task_id, msg, progress)
else:
_LOGGER.warning("No task ID found, skipping database update")
except Exception as e:
        # Don't raise here: carry on even if the database update fails,
        # although the user then won't see this update on the status page
_LOGGER.exception(f"Failed to update database: {e}")
_LOGGER.info(f"Continuing despite database error. Message was: '{msg}'")
def database_progress_percentage_calculate(progress: tuple[int, int] | None) -> int:
"""Calculate percentage from progress tuple."""
_LOGGER.debug(f"Calculating percentage from progress tuple: {progress}")
if progress is None:
_LOGGER.debug("Progress is None, returning 0%")
return 0
current, total = progress
# Avoid division by zero
if total == 0:
_LOGGER.warning("Total is zero in progress tuple, avoiding division by zero")
return 0
percentage = min(100, int((current / total) * 100))
_LOGGER.debug(f"Calculated percentage: {percentage}% ({current}/{total})")
return percentage
async def database_task_id_get() -> int | None:
"""Get current task ID asynchronously with caching."""
global global_task_id
_LOGGER.debug("Attempting to get current task ID")
# Return cached ID if available
if global_task_id is not None:
_LOGGER.debug(f"Using cached task ID: {global_task_id}")
return global_task_id
try:
process_id = os.getpid()
_LOGGER.debug(f"Current process ID: {process_id}")
task_id = await database_task_pid_lookup(process_id)
if task_id:
_LOGGER.info(f"Found task ID: {task_id} for process ID: {process_id}")
# Cache the task ID for future use
global_task_id = task_id
else:
_LOGGER.warning(f"No task found for process ID: {process_id}")
return task_id
except Exception as e:
_LOGGER.exception(f"Error getting task ID: {e}")
return None
async def database_task_pid_lookup(process_id: int) -> int | None:
"""Look up task ID by process ID asynchronously."""
_LOGGER.debug(f"Looking up task ID for process ID: {process_id}")
try:
async with await get_db_session() as session:
_LOGGER.debug(f"Executing SQL query to find task for PID: {process_id}")
# Look for ACTIVE task with our PID
result = await session.execute(
sqlalchemy.text("""
SELECT id FROM task
WHERE pid = :pid AND status = 'ACTIVE'
LIMIT 1
"""),
{"pid": process_id},
)
_LOGGER.debug("SQL query executed, fetching results")
row = result.fetchone()
if row:
_LOGGER.info(f"Found task ID: {row[0]} for process ID: {process_id}")
row_one = row[0]
if not isinstance(row_one, int):
_LOGGER.error(f"Task ID is not an integer: {row_one}")
raise ValueError("Task ID is not an integer")
return row_one
else:
_LOGGER.warning(f"No ACTIVE task found for process ID: {process_id}")
return None
except Exception as e:
_LOGGER.exception(f"Error looking up task by PID: {e}")
return None
async def database_task_update(task_id: int, msg: str, progress: tuple[int, int] | None) -> None:
"""Update task in database with message and progress."""
_LOGGER.debug(f"Updating task {task_id} with message: '{msg}', progress: {progress}")
# Convert progress to percentage
progress_pct = database_progress_percentage_calculate(progress)
_LOGGER.debug(f"Calculated progress percentage: {progress_pct}%")
await database_task_update_execute(task_id, msg, progress_pct)
async def database_task_update_execute(task_id: int, msg: str, progress_pct: int) -> None:
"""Execute database update with message and progress."""
_LOGGER.debug(f"Executing database update for task {task_id}, message: '{msg}', progress: {progress_pct}%")
try:
async with await get_db_session() as session:
_LOGGER.debug(f"Executing SQL UPDATE for task ID: {task_id}")
# Store progress info in the result column as JSON
result_data = json.dumps({"message": msg, "progress": progress_pct})
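            # e.g. '{"message": "Downloading 3/10 artifacts", "progress": 30}' (illustrative values);
            # this is the shape that templates/release-bulk.html expects (see download_core)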
await session.execute(
sqlalchemy.text("""
UPDATE task
SET result = :result
WHERE id = :task_id
"""),
{
"result": result_data,
"task_id": task_id,
},
)
await session.commit()
_LOGGER.info(f"Successfully updated task {task_id} with progress {progress_pct}%")
except Exception as e:
# Continue even if database update fails
_LOGGER.exception(f"Error updating task {task_id} in database: {e}")
async def download(args: dict[str, Any]) -> tuple[models.TaskStatus, str | None, tuple[Any, ...]]:
"""Download bulk package from URL."""
# Returns (status, error, result)
# This is the main task entry point, called by worker.py
# This function should probably be called artifacts_download
_LOGGER.info(f"Starting bulk download task with args: {args}")
try:
_LOGGER.debug("Delegating to download_core function")
status, error, result = await download_core(args)
_LOGGER.info(f"Download completed with status: {status}")
return status, error, result
except Exception as e:
_LOGGER.exception(f"Error in download function: {e}")
# Return a tuple with a dictionary that matches what the template expects
return task.FAILED, str(e), ({"message": f"Error: {e}", "progress": 0},)
async def download_core(args_dict: dict[str, Any]) -> tuple[models.TaskStatus, str | None, tuple[Any, ...]]:
"""Download bulk package from URL."""
_LOGGER.info("Starting download_core")
try:
_LOGGER.debug(f"Parsing arguments: {args_dict}")
args = Args.from_dict(args_dict)
_LOGGER.info(f"Args parsed successfully: release_name={args.release_name}, base_url={args.base_url}")
# Create async resources
_LOGGER.debug("Creating async queue and semaphore")
queue: asyncio.Queue[str] = asyncio.Queue()
semaphore = asyncio.Semaphore(args.max_concurrent)
# Start URL crawling
await database_message(f"Crawling URLs from {args.base_url}")
_LOGGER.info("Starting artifact_urls coroutine")
signatures, artifacts = await artifact_urls(args, queue, semaphore)
_LOGGER.info(f"Found {len(signatures)} signatures and {len(artifacts)} artifacts")
# Update progress for download phase
await database_message(f"Found {len(artifacts)} artifacts to download")
# Download artifacts
_LOGGER.info("Starting artifacts_download coroutine")
artifacts_downloaded = await artifacts_download(artifacts, semaphore)
files_downloaded = len(artifacts_downloaded)
# Return a result dictionary
# This matches what we have in templates/release-bulk.html
return (
task.COMPLETED,
None,
(
{
"message": f"Successfully downloaded {files_downloaded} artifacts",
"progress": 100,
"url": args.base_url,
"file_types": args.file_types,
"files_downloaded": files_downloaded,
},
),
)
except Exception as e:
_LOGGER.exception(f"Error in download_core: {e}")
        base_url = args_dict.get("base_url", "unknown URL")
return (
task.FAILED,
str(e),
(
{
"message": f"Failed to download from {base_url}",
"progress": 0,
},
),
)
async def download_html(url: str, semaphore: asyncio.Semaphore) -> list[str]:
"""Download HTML and extract links."""
_LOGGER.debug(f"Downloading HTML from: {url}")
try:
return await download_html_core(url, semaphore)
except Exception as e:
_LOGGER.error(f"Error downloading HTML from {url}: {e}")
return []
async def download_html_core(url: str, semaphore: asyncio.Semaphore) -> list[str]:
"""Core HTML download and link extraction logic."""
_LOGGER.debug(f"Starting HTML download core for {url}")
async with semaphore:
_LOGGER.debug(f"Acquired semaphore for {url}")
urls = []
async with aiohttp.ClientSession() as session:
_LOGGER.debug(f"Created HTTP session for {url}")
async with session.get(url) as response:
if response.status != 200:
_LOGGER.warning(f"HTTP {response.status} for {url}")
return []
_LOGGER.debug(f"Received HTTP 200 for {url}, content type: {response.content_type}")
if not response.content_type.startswith("text/html"):
_LOGGER.debug(f"Not HTML content: {response.content_type}, skipping link extraction")
return []
_LOGGER.debug(f"Reading HTML content from {url}")
                html_text = await response.text()
                urls = extract_links_from_html(html_text, url)
_LOGGER.debug(f"Extracted {len(urls)} processed links from {url}")
return urls
def extract_links_from_html(html_text: str, base_url: str) -> list[str]:
    """Extract links from HTML content using html.parser."""
    parser = LinkExtractor()
    parser.feed(html_text)
raw_links = parser.links
_LOGGER.debug(f"Found {len(raw_links)} raw links in {base_url}")
processed_urls = []
for link in raw_links:
processed_url = urllib.parse.urljoin(base_url, link)
        # Filter out URLs that don't start with the base URL.
        # url_excluded() checks this again along with other conditions,
        # but filtering early keeps the queue small.
if processed_url.startswith(base_url):
processed_urls.append(processed_url)
else:
_LOGGER.debug(f"Skipping URL outside base URL scope: {processed_url}")
return processed_urls
async def get_db_session() -> sqlalchemy.ext.asyncio.AsyncSession:
"""Get a reusable database session."""
global global_db_connection
try:
# Create connection only if it doesn't exist already
if global_db_connection is None:
conf = config.get()
absolute_db_path = os.path.join(conf.STATE_DIR, conf.SQLITE_DB_PATH)
            # Three slashes are required before either a relative or absolute path
            db_url = f"sqlite+aiosqlite:///{absolute_db_path}"
_LOGGER.debug(f"Creating database engine: {db_url}")
engine = sqlalchemy.ext.asyncio.create_async_engine(db_url)
global_db_connection = sqlalchemy.ext.asyncio.async_sessionmaker(
engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False
)
connection: sqlalchemy.ext.asyncio.AsyncSession = global_db_connection()
return connection
except Exception as e:
_LOGGER.exception(f"Error creating database session: {e}")
raise
def url_excluded(seen: set[str], url: str, args: Args) -> bool:
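    """Return True if url should be skipped: outside base_url, already seen, or a sorting link.

    The sorting patterns below match the column-sorting query strings that
    directory index pages (e.g. Apache httpd autoindex) append, such as "?C=N;O=D".
    """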
# Filter for sorting URLs to avoid redundant crawling
sorting_patterns = ["?C=N;O=", "?C=M;O=", "?C=S;O=", "?C=D;O="]
if not url.startswith(args.base_url):
_LOGGER.debug(f"Skipping URL outside base URL scope: {url}")
return True
if url in seen:
_LOGGER.debug(f"Skipping already seen URL: {url}")
return True
# Skip sorting URLs to avoid redundant crawling
if any(pattern in url for pattern in sorting_patterns):
_LOGGER.debug(f"Skipping sorting URL: {url}")
return True
return False