dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs (270 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use super::Credits; use crate::stream::{ send::{ error::{self, Error}, flow, }, TransportFeatures, }; use atomic_waker::AtomicWaker; use core::{ fmt, sync::atomic::{AtomicU64, Ordering}, task::{Context, Poll}, }; use s2n_quic_core::{ensure, varint::VarInt}; use std::sync::OnceLock; const ERROR_MASK: u64 = 1 << 63; const FINISHED_MASK: u64 = 1 << 62; const OFFSET_MASK: u64 = !(ERROR_MASK | FINISHED_MASK); pub struct State { /// Monotonic offset which tracks where the application is currently writing stream_offset: AtomicU64, /// Monotonic offset which indicates the maximum offset the application can write to flow_offset: AtomicU64, /// Notifies an application of newly-available flow credits poll_waker: AtomicWaker, stream_error: OnceLock<Error>, // TODO add a list for the `acquire` future wakers } impl fmt::Debug for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("flow::non_blocking::State") .field("stream_offset", &self.stream_offset.load(Ordering::Relaxed)) .field("flow_offset", &self.flow_offset.load(Ordering::Relaxed)) .finish() } } impl State { #[inline] pub fn new(initial_flow_offset: VarInt) -> Self { Self { stream_offset: AtomicU64::new(0), flow_offset: AtomicU64::new(initial_flow_offset.as_u64()), poll_waker: AtomicWaker::new(), stream_error: OnceLock::new(), } } } impl State { #[inline] pub fn stream_offset(&self) -> VarInt { let value = self.stream_offset.load(Ordering::Relaxed); // mask off the two upper bits let value = value & OFFSET_MASK; unsafe { VarInt::new_unchecked(value) } } /// Called by the background worker to release flow credits /// /// Callers MUST ensure the provided offset is monotonic. #[inline] pub fn release(&self, flow_offset: VarInt) { tracing::trace!(release = %flow_offset); self.flow_offset .store(flow_offset.as_u64(), Ordering::Release); self.poll_waker.wake(); } /// Called by the background worker to release flow credits /// /// This version only releases credits if the value is strictly less than the previous value #[inline] pub fn release_max(&self, flow_offset: VarInt) { tracing::trace!(release = %flow_offset); let prev = self .flow_offset .fetch_max(flow_offset.as_u64(), Ordering::Release); // if the flow offset was updated then wake the application waker if prev < flow_offset.as_u64() { self.poll_waker.wake(); } } #[inline] pub fn set_error(&self, error: Error) { let _ = self.stream_error.set(error); self.stream_offset.fetch_or(ERROR_MASK, Ordering::Relaxed); self.poll_waker.wake(); } /// Called by the application to acquire flow credits #[inline] pub async fn acquire( &self, request: flow::Request, features: &TransportFeatures, ) -> Result<Credits, Error> { core::future::poll_fn(|cx| self.poll_acquire(cx, request, features)).await } /// Called by the application to acquire flow credits #[inline] pub fn poll_acquire( &self, cx: &mut Context, mut request: flow::Request, features: &TransportFeatures, ) -> Poll<Result<Credits, Error>> { let mut current_offset = self.acquire_offset(&request)?; let mut stored_waker = false; loop { let flow_offset = self.flow_offset.load(Ordering::Acquire); let Some(flow_credits) = flow_offset .checked_sub(current_offset & OFFSET_MASK) .filter(|v| { // if we're finishing the stream and don't have any buffered data, then we // don't need any flow control if request.len == 0 && request.is_fin { true } else { *v > 0 } }) else { // if we already stored a waker and didn't get more credits then yield the task ensure!(!stored_waker, Poll::Pending); stored_waker = true; self.poll_waker.register(cx.waker()); // make one last effort to acquire some flow credits before going to sleep current_offset = self.acquire_offset(&request)?; continue; }; if !features.is_flow_controlled() { // clamp the request to the flow credits we have request.clamp(flow_credits); } let mut new_offset = (current_offset & OFFSET_MASK) .checked_add(request.len as u64) .filter(|v| *v <= VarInt::MAX.as_u64()) .ok_or_else(|| error::Kind::PayloadTooLarge.err())?; // record that we've sent the final offset if request.is_fin || current_offset & FINISHED_MASK == FINISHED_MASK { new_offset |= FINISHED_MASK; } let result = self.stream_offset.compare_exchange( current_offset, new_offset, Ordering::Release, // TODO is this the correct ordering? Ordering::Acquire, ); match result { Ok(_) => { // the offset was correctly updated so return our acquired credits let acquired_offset = unsafe { VarInt::new_unchecked(current_offset & OFFSET_MASK) }; let credits = request.response(acquired_offset); return Poll::Ready(Ok(credits)); } Err(updated_offset) => { // the offset was updated from underneath us so try again current_offset = self.process_offset(updated_offset, &request)?; // clear the fact that we stored the waker, since we need to do a full sync // to get the correct state stored_waker = false; continue; } } } } #[inline] fn acquire_offset(&self, request: &flow::Request) -> Result<u64, Error> { self.process_offset(self.stream_offset.load(Ordering::Acquire), request) } #[inline] fn process_offset(&self, offset: u64, request: &flow::Request) -> Result<u64, Error> { if offset & ERROR_MASK == ERROR_MASK { let error = self .stream_error .get() .copied() .unwrap_or_else(|| error::Kind::FatalError.err()); return Err(error); } if offset & FINISHED_MASK == FINISHED_MASK { ensure!(request.len == 0, Err(error::Kind::FinalSizeChanged.err())); } Ok(offset) } } #[cfg(test)] mod tests { use bolero::check; use super::*; use crate::stream::send::path; use std::sync::Arc; #[tokio::test] async fn concurrent_flow() { let mut initial_offset = VarInt::from_u8(255); let expected_len = VarInt::from_u16(u16::MAX); let state = Arc::new(State::new(initial_offset)); let path_info = path::Info { max_datagram_size: 1500, send_quantum: 10, ecn: Default::default(), next_expected_control_packet: Default::default(), }; let total = Arc::new(AtomicU64::new(0)); // TODO support more than one Waker via intrusive list or something let workers = 1; let worker_counts = Vec::from_iter((0..workers).map(|_| Arc::new(AtomicU64::new(0)))); let features = TransportFeatures::UDP; let mut tasks = tokio::task::JoinSet::new(); for (idx, count) in worker_counts.iter().cloned().enumerate() { let total = total.clone(); let state = state.clone(); tasks.spawn(async move { tokio::time::sleep(core::time::Duration::from_millis(10)).await; let mut buffer_len = 1; let mut is_fin = false; let max_segments = 10; let max_header_len = 50; let mut max_offset = VarInt::ZERO; loop { let mut request = flow::Request { len: buffer_len, initial_len: buffer_len, is_fin, }; request.clamp(path_info.max_flow_credits(max_header_len, max_segments)); let Ok(credits) = state.acquire(request, &features).await else { break; }; println!( "thread={idx} offset={}..{} is_fin={}", credits.offset, credits.offset + credits.len, credits.is_fin, ); buffer_len += 1; buffer_len = buffer_len.min( expected_len .as_u64() .saturating_sub(credits.offset.as_u64()) .saturating_sub(credits.len as u64) as usize, ); assert!(max_offset <= credits.offset); max_offset = credits.offset; if buffer_len == 0 { // we already wrote our fin if is_fin { break; } is_fin = true; } total.fetch_add(credits.len as _, Ordering::Relaxed); count.fetch_add(credits.len as _, Ordering::Relaxed); } }); } tasks.spawn(async move { let mut credits = 10; while initial_offset < expected_len { tokio::time::sleep(core::time::Duration::from_millis(1)).await; initial_offset = (initial_offset + credits).min(expected_len); credits += 1; state.release(initial_offset); } }); // make sure all of the tasks complete while tasks.join_next().await.is_some() {} assert_eq!(total.load(Ordering::Relaxed), expected_len.as_u64()); let mut at_least_one_write = true; for (idx, count) in worker_counts.into_iter().enumerate() { let count = count.load(Ordering::Relaxed); println!("thread={idx}, count={}", count); if count == 0 { at_least_one_write = false; } } assert!( at_least_one_write, "all workers need to write at least one byte" ); } #[test] fn error_test() { check!() .with_type::<(u8, u8, bool)>() .cloned() .for_each(|(initial_offset, len, is_fin)| { let state = Arc::new(State::new(VarInt::from_u8(initial_offset))); state.set_error(Error::new(error::Kind::FatalError)); let len = len as _; let request = flow::Request { len, initial_len: len, is_fin, }; state.acquire_offset(&request).unwrap_err(); }) } }