dc/s2n-quic-dc/src/stream/recv/dispatch/tests.rs (359 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use super::*; use crate::{ socket::recv, stream::Actor, testing::{ext::*, sim}, }; use bolero::{check, TypeGenerator}; use s2n_quic_core::varint::VarInt; use std::{collections::BTreeMap, panic::AssertUnwindSafe}; #[derive(Clone, Debug, TypeGenerator)] enum Op { Alloc, FreeControl { idx: u16 }, FreeStream { idx: u16 }, SendControl { idx: u16 }, SendStream { idx: u16, inject: bool }, DropAllocator, DropDispatcher, } struct Model { oracle: Oracle, alloc: Option<Allocator>, dispatch: Option<Dispatch>, } impl Default for Model { fn default() -> Self { Self::new(Default::default(), false) } } impl Model { fn new(packets: Packets, non_zero: bool) -> Self { let stream_cap = 32; let control_cap = 8; let alloc = if non_zero { Allocator::new_non_zero(stream_cap, control_cap) } else { Allocator::new(stream_cap, control_cap) }; let dispatch = alloc.dispatcher(); let oracle = Oracle::new(packets); Self { oracle, alloc: Some(alloc), dispatch: Some(dispatch), } } fn apply(&mut self, op: &Op) { match op { Op::Alloc => { self.alloc(); } Op::FreeControl { idx } => { self.free_control((*idx).into()); } Op::FreeStream { idx } => { self.free_stream((*idx).into()); } Op::SendControl { idx } => { self.send_control((*idx).into()); } Op::SendStream { idx, inject } => { self.send_stream((*idx).into(), *inject); } Op::DropAllocator => { self.alloc = None; } Op::DropDispatcher => { self.dispatch = None; } } } fn alloc(&mut self) { let Some(alloc) = self.alloc.as_mut() else { return; }; let (control, stream) = alloc.alloc_or_grow(None); self.oracle.on_alloc(control, stream); } fn free_control(&mut self, idx: VarInt) { let _ = self.oracle.control.remove(&idx); } fn free_stream(&mut self, idx: VarInt) { let _ = self.oracle.stream.remove(&idx); } fn send_control(&mut self, queue_id: VarInt) { let Some(dispatch) = self.dispatch.as_mut() else { return; }; let (packet_id, packet) = self.oracle.packets.create(); let res = dispatch.send_control(queue_id, packet); self.oracle.on_control_dispatch(queue_id, packet_id, res); } fn send_stream(&mut self, queue_id: VarInt, inject: bool) { if inject { return self.oracle.send_stream_inject(queue_id); } let Some(dispatch) = self.dispatch.as_mut() else { return; }; let (packet_id, packet) = self.oracle.packets.create(); let res = dispatch.send_stream(queue_id, packet); self.oracle.on_stream_dispatch(queue_id, packet_id, res); } } struct Oracle { stream: BTreeMap<VarInt, Stream>, control: BTreeMap<VarInt, Control>, packets: Packets, } impl Oracle { fn new(packets: Packets) -> Self { Self { packets, stream: Default::default(), control: Default::default(), } } fn on_alloc(&mut self, control: Control, stream: Stream) { let queue_id = control.queue_id(); assert_eq!(queue_id, stream.queue_id(), "queue IDs should match"); assert!( control.try_recv().unwrap().is_none(), "queue should be empty" ); assert!( stream.try_recv().unwrap().is_none(), "queue should be empty" ); assert!( self.control.insert(queue_id, control).is_none(), "queue ID should be unique" ); assert!( self.stream.insert(queue_id, stream).is_none(), "queue ID should be unique" ); } fn on_control_dispatch( &mut self, idx: VarInt, packet_id: u64, result: Result<Option<desc::Filled>, Error>, ) { let Some(channel) = self.control.get(&idx) else { assert!(result.is_err()); return; }; assert!(result.is_ok()); let actual = channel.try_recv().unwrap().unwrap(); assert_eq!( actual.payload(), packet_id.to_be_bytes(), "queue should contain expected packet id" ); assert!( channel.try_recv().unwrap().is_none(), "queue should be empty now" ); } fn on_stream_dispatch( &mut self, idx: VarInt, packet_id: u64, result: Result<Option<desc::Filled>, Error>, ) { let Some(channel) = self.stream.get(&idx) else { assert!(result.is_err()); return; }; assert!(result.is_ok()); let actual = channel.try_recv().unwrap().unwrap(); assert_eq!( actual.payload(), packet_id.to_be_bytes(), "queue should contain expected packet id" ); assert!( channel.try_recv().unwrap().is_none(), "queue should be empty now" ); } fn send_stream_inject(&mut self, idx: VarInt) { let Some(channel) = self .stream .get(&idx) .or_else(|| self.stream.first_key_value().map(|(_k, v)| v)) else { return; }; let (packet_id, packet) = self.packets.create(); assert!(channel.push(packet).is_none(), "queue should accept packet"); let actual = channel.try_recv().unwrap().unwrap(); assert_eq!( actual.payload(), packet_id.to_be_bytes(), "queue should contain expected packet id" ); if matches!(channel.try_recv(), Ok(Some(_))) { panic!("queue should be empty or errored"); } } } #[derive(Clone)] struct Packets { packets: recv::pool::Pool, packet_id: u64, } impl Default for Packets { fn default() -> Self { Self { packets: recv::pool::Pool::new(8, 8), packet_id: Default::default(), } } } impl Packets { fn create(&mut self) -> (u64, recv::descriptor::Filled) { let packet_id = self.packet_id; self.packet_id += 1; let unfilled = self.packets.alloc_or_grow(); let packet = unfilled .recv_with(|_addr, _cmsg, mut payload| { let v = packet_id.to_be_bytes(); payload[..v.len()].copy_from_slice(&v); <std::io::Result<_>>::Ok(v.len()) }) .unwrap() .next() .unwrap(); (packet_id, packet) } } #[test] fn model_test() { crate::testing::init_tracing(); // create a Packet allocator once to avoid setup/teardown costs let packets = AssertUnwindSafe(Packets::default()); check!() .with_type::<(bool, Vec<Op>)>() .with_test_time(core::time::Duration::from_secs(30)) .for_each(move |(non_zero, ops)| { let mut model = Model::new(packets.clone(), *non_zero); for op in ops { model.apply(op); } }); } /// ensure that freeing an allocator notifies all of the open receivers #[test] fn alloc_drop_notify() { sim(|| { let stream_cap = 1; let control_cap = 1; let mut alloc = Allocator::new(stream_cap, control_cap); for _ in 0..2 { let (stream, control) = alloc.alloc_or_grow(None); async move { stream.recv(Actor::Application).await.unwrap_err(); } .primary() .spawn(); async move { control.recv(Actor::Application).await.unwrap_err(); } .primary() .spawn(); } async move { core::time::Duration::from_millis(100).sleep().await; drop(alloc); } .spawn(); }); } #[test] fn associated_credentials() { check!().exhaustive().run(|| { let mut alloc = Allocator::new(1, 1); let alloc = &mut alloc; let mut key_id = VarInt::ZERO; let key_id = &mut key_id; let mut alloc_one = move || { let creds = Credentials { id: credentials::Id::default(), key_id: *key_id, }; *key_id += 1; let (stream, control) = alloc.alloc_or_grow(Some(&creds)); let keys = alloc.pool.keys(); move || { assert_eq!(keys.get(&creds), Some(control.queue_id())); assert_eq!(keys.get(&creds), Some(stream.queue_id())); // interleave the order in which we drop channels if bolero::any() { drop(stream); assert_eq!( keys.get(&creds), Some(control.queue_id()), "credentials should still match" ); drop(control); } else { drop(control); assert_eq!( keys.get(&creds), Some(stream.queue_id()), "credentials should still match" ); drop(stream); } assert_eq!(keys.get(&creds), None, "credentials should be removed"); } }; let mut channels = [alloc_one(), alloc_one()]; // change the order that channels are freed channels.shuffle(); for channel in channels { channel(); } }); } #[test] fn stress_test() { use std::{ sync::{ atomic::{AtomicBool, Ordering}, mpsc::sync_channel as channel, }, thread::{scope, sleep, yield_now}, time::Duration, }; crate::testing::init_tracing(); let mut alloc = Allocator::new(1, 1); let (stream_send, stream_recv) = channel(10); let (control_send, control_recv) = channel(10); scope(|s| { static IS_OPEN: AtomicBool = AtomicBool::new(true); s.spawn(|| { sleep(Duration::from_secs(1)); IS_OPEN.store(false, Ordering::Relaxed); }); let alloc = s.spawn(move || { while IS_OPEN.load(Ordering::Relaxed) { let (control, stream) = alloc.alloc_or_grow(None); stream_send.send(stream).unwrap(); control_send.send(control).unwrap(); } }); s.spawn(move || { while let Ok(stream) = stream_recv.recv() { yield_now(); drop(stream); } }); s.spawn(move || { while let Ok(control) = control_recv.recv() { yield_now(); drop(control); } }); alloc.join().unwrap(); }); }