progress_tracking/src/verification_wrapper.rs (218 lines of code) (raw):

use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use more_asserts::assert_le; use tokio::sync::Mutex; use crate::{ProgressUpdate, TrackingProgressUpdater}; /// Internal structure to track and validate progress data for one item. #[derive(Debug)] struct ItemProgressData { total_count: u64, last_completed: u64, } #[derive(Debug, Default)] pub struct ProgressUpdaterVerificationWrapperImpl { items: HashMap<Arc<str>, ItemProgressData>, total_transfer_bytes: u64, total_transfer_bytes_completed: u64, total_bytes: u64, total_process_bytes_completed: u64, } /// A wrapper that forwards updates to an inner `TrackingProgressUpdater` /// while also validating each update for correctness: /// /// - `completed_count` must be non-decreasing and never exceed `total_count`. /// - `completed_count` must match `last_completed + update_increment`. /// - `total_count` must remain consistent (if it changes across updates for the same item, that's an error). /// - Final verification (`assert_complete()`) ensures all items reached `completed_count == total_count`. pub struct ProgressUpdaterVerificationWrapper { inner: Arc<dyn TrackingProgressUpdater>, tr: Mutex<ProgressUpdaterVerificationWrapperImpl>, } impl ProgressUpdaterVerificationWrapper { /// Creates a new verification wrapper around an existing `TrackingProgressUpdater`. /// All updates are validated and then forwarded to `inner`. pub fn new(inner: Arc<dyn TrackingProgressUpdater>) -> Arc<Self> { Arc::new(Self { inner, tr: Mutex::new(ProgressUpdaterVerificationWrapperImpl::default()), }) } /// Once all uploads are done, call this to ensure that every item is fully complete. /// Panics if any item is still incomplete (i.e. `last_completed < total_count`). pub async fn assert_complete(&self) { let tr = self.tr.lock().await; for (item_name, data) in tr.items.iter() { assert_eq!( data.last_completed, data.total_count, "Item '{}' is not fully complete: {}/{}", item_name, data.last_completed, data.total_count ); } assert_eq!(tr.total_transfer_bytes_completed, tr.total_transfer_bytes); } } #[async_trait] impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { async fn register_updates(&self, update: ProgressUpdate) { // First, capture and validate let mut tr = self.tr.lock().await; for up in update.item_updates.iter() { let entry = tr.items.entry(up.item_name.clone()).or_insert(ItemProgressData { total_count: 0, last_completed: 0, }); // If first time seeing total_count for this item, record it. // Otherwise, ensure it stays consistent. if entry.total_count == 0 { entry.total_count = up.total_bytes; } else { assert_eq!( entry.total_count, up.total_bytes, "Inconsistent total_count for '{}'; was {}, now {}", up.item_name, entry.total_count, up.total_bytes ); } // Check increments: // 1) `completed_count` should never go down assert!( up.bytes_completed >= entry.last_completed, "Item '{}' completed_count went backwards: old={}, new={}", up.item_name, entry.last_completed, up.bytes_completed ); // 2) `completed_count` must not exceed `total_count` assert!( up.bytes_completed <= up.total_bytes, "Item '{}' completed_count {} exceeds total {}", up.item_name, up.bytes_completed, up.total_bytes ); // 3) The increment must match the difference let expected_new = entry.last_completed + up.bytes_completion_increment; assert_eq!( up.bytes_completed, expected_new, "Item '{}': mismatch: last_completed={} + update_increment={} != completed_count={}", up.item_name, entry.last_completed, up.bytes_completion_increment, up.bytes_completed ); // Update item record entry.last_completed = up.bytes_completed; } assert_le!( tr.total_transfer_bytes, update.total_transfer_bytes, "New total bytes {} a decrease from previous report of total bytes {}", update.total_transfer_bytes, tr.total_transfer_bytes ); tr.total_transfer_bytes += update.total_transfer_bytes_increment; assert_eq!( tr.total_transfer_bytes, update.total_transfer_bytes, "New increment {} put tracked checked transfer bytes {} out of step from reported total bytes {}", update.total_transfer_bytes_increment, tr.total_transfer_bytes, update.total_transfer_bytes, ); assert_le!( tr.total_transfer_bytes_completed, update.total_transfer_bytes_completed, "New total bytes completed {} a decrease from previous report of total bytes {}", update.total_transfer_bytes_completed, tr.total_transfer_bytes_completed ); tr.total_transfer_bytes_completed += update.total_transfer_bytes_completion_increment; assert_eq!( tr.total_transfer_bytes_completed, update.total_transfer_bytes_completed, "Total bytes completed {} does not match tracked total bytes {}", update.total_transfer_bytes_completed, tr.total_transfer_bytes_completed ); assert_le!( tr.total_bytes, update.total_bytes, "New total bytes {} a decrease from previous report of total bytes {}", update.total_bytes, tr.total_bytes ); tr.total_bytes += update.total_bytes_increment; assert_eq!( tr.total_bytes, update.total_bytes, "New increment {} put checked total processing bytes {} out of step from reported total bytes {}", update.total_bytes_increment, tr.total_bytes, update.total_bytes, ); assert_le!( tr.total_process_bytes_completed, update.total_bytes_completed, "New total bytes completed {} a decrease from previous report of total bytes {}", update.total_bytes_completed, tr.total_process_bytes_completed ); tr.total_process_bytes_completed += update.total_bytes_completion_increment; assert_eq!( tr.total_process_bytes_completed, update.total_bytes_completed, "Total bytes completed {} does not match tracked total bytes {}", update.total_bytes_completed, tr.total_process_bytes_completed ); // Now forward them to the inner updater self.inner.register_updates(update).await; } async fn flush(&self) { self.inner.flush().await; } } #[cfg(test)] mod tests { use super::*; use crate::ItemProgressUpdate; /// A trivial `TrackingProgressUpdater` for testing, which just stores all updates. /// In real code, this could log to a file, update a UI, etc. #[derive(Debug, Default)] struct DummyLogger { pub all_updates: Mutex<Vec<ItemProgressUpdate>>, } #[async_trait] impl TrackingProgressUpdater for DummyLogger { async fn register_updates(&self, updates: ProgressUpdate) { let mut guard = self.all_updates.lock().await; guard.extend_from_slice(&updates.item_updates); } } #[tokio::test] async fn test_verification_wrapper() { // Create an actual inner logger or progress sink let logger = Arc::new(DummyLogger::default()); // Wrap it with our verification wrapper let wrapper = ProgressUpdaterVerificationWrapper::new(logger.clone()); // Let's register some progress updates wrapper .register_updates(ProgressUpdate { item_updates: vec![ ItemProgressUpdate { item_name: Arc::from("fileA"), total_bytes: 100, bytes_completed: 50, bytes_completion_increment: 50, }, ItemProgressUpdate { item_name: Arc::from("fileB"), total_bytes: 200, bytes_completed: 100, bytes_completion_increment: 100, }, ], total_transfer_bytes: 100, total_transfer_bytes_increment: 100, total_transfer_bytes_completed: 50, total_transfer_bytes_completion_increment: 50, total_bytes: 200, total_bytes_increment: 200, total_bytes_completed: 100, total_bytes_completion_increment: 100, ..Default::default() }) .await; // Shouldn't be complete yet. We'll do one more set of updates to finalize. wrapper .register_updates(ProgressUpdate { item_updates: vec![ ItemProgressUpdate { item_name: Arc::from("fileA"), total_bytes: 100, bytes_completed: 100, bytes_completion_increment: 50, }, ItemProgressUpdate { item_name: Arc::from("fileB"), total_bytes: 200, bytes_completed: 200, bytes_completion_increment: 100, }, ], total_transfer_bytes: 150, total_transfer_bytes_increment: 50, total_transfer_bytes_completed: 150, total_transfer_bytes_completion_increment: 100, total_bytes: 200, total_bytes_increment: 0, total_bytes_completed: 200, total_bytes_completion_increment: 100, ..Default::default() }) .await; // Now all items should be fully complete wrapper.assert_complete().await; // We can also inspect the inner logger's captured updates: let final_updates = logger.all_updates.lock().await; assert_eq!(final_updates.len(), 4, "We sent 4 updates total"); } }