netbench-driver-s2n-quic/src/lib.rs (73 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use bytes::Bytes; use core::{ future::Future, pin::Pin, task::{Context, Poll}, }; use netbench::{client, connection::Owner, helper::IdPrefixReader, scenario, Driver, Result}; use s2n_quic::{ connection, stream::{LocalStream, PeerStream, SplittableStream}, }; use s2n_quic_core::stream::testing::Data; use std::task::ready; use std::{ collections::{hash_map::Entry, HashMap}, ops, sync::Arc, }; fn stream_error(err: s2n_quic::stream::Error) -> Result<()> { if let s2n_quic::stream::Error::StreamReset { error, .. } = err { if *error == 0 { return Ok(()); } } if let s2n_quic::stream::Error::ConnectionError { error, .. } = err { return conn_error(error); } Err(err.into()) } fn conn_error(err: s2n_quic::connection::Error) -> Result<()> { if let s2n_quic::connection::Error::Application { error, .. } = err { if *error == 0 { return Ok(()); } } Err(err.into()) } pub struct Client(pub s2n_quic::Client); impl ops::Deref for Client { type Target = s2n_quic::Client; #[inline] fn deref(&self) -> &Self::Target { &self.0 } } impl ops::DerefMut for Client { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } impl<'a> client::Client<'a> for Client { type Connect = Connect<'a>; type Connection = Driver<'a, Connection>; fn connect( &mut self, addr: std::net::SocketAddr, server_name: &str, _server_conn_id: u64, scenario: &'a Arc<scenario::Connection>, ) -> Self::Connect { let connect = s2n_quic::client::Connect::new(addr).with_server_name(server_name); let attempt = s2n_quic::Client::connect(self, connect); Connect { attempt, scenario } } } pub struct Connect<'a> { attempt: s2n_quic::client::ConnectionAttempt, scenario: &'a scenario::Connection, } impl<'a> Future for Connect<'a> { type Output = Result<crate::Driver<'a, Connection>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let conn = ready!(Pin::new(&mut self.attempt).poll(cx))?; let conn = Connection::new(conn); let conn = crate::Driver::new(self.scenario, conn); Ok(conn).into() } } pub struct Connection { conn: s2n_quic::Connection, streams: [HashMap<u64, Stream>; 2], opened_streams: HashMap<u64, (Bytes, LocalStream)>, unidentified_peer_stream: Option<(IdPrefixReader, PeerStream)>, } impl From<s2n_quic::Connection> for Connection { fn from(conn: s2n_quic::Connection) -> Self { Self::new(conn) } } impl Connection { pub fn new(connection: s2n_quic::Connection) -> Self { Self { conn: connection, streams: [HashMap::new(), HashMap::new()], opened_streams: HashMap::new(), unidentified_peer_stream: Default::default(), } } pub fn into_inner(self) -> s2n_quic::Connection { self.conn } fn open_local_stream< F: FnOnce(&mut s2n_quic::Connection, &mut Context) -> Poll<Result<S, connection::Error>>, S: Into<LocalStream>, >( &mut self, id: u64, open: F, cx: &mut Context, ) -> Poll<Result<()>> { // the stream has already been opened and is waiting to send the prefix if let Entry::Occupied(mut entry) = self.opened_streams.entry(id) { let (prefix, stream) = entry.get_mut(); return match stream.poll_send(prefix, cx) { Poll::Ready(Ok(_)) => { let (_, stream) = entry.remove(); let stream = Stream::new(stream); self.streams[Owner::Local].insert(id, stream); Poll::Ready(Ok(())) } Poll::Ready(Err(err)) => { entry.remove(); Poll::Ready(stream_error(err)) } Poll::Pending => Poll::Pending, }; } let mut stream = ready!(open(&mut self.conn, cx))?.into(); let mut prefix = Bytes::copy_from_slice(&id.to_be_bytes()); match stream.poll_send(&mut prefix, cx) { Poll::Ready(Ok(_)) => { let stream = Stream::new(stream); self.streams[Owner::Local].insert(id, stream); Poll::Ready(Ok(())) } Poll::Ready(Err(err)) => Poll::Ready(stream_error(err)), Poll::Pending => { self.opened_streams.insert(id, (prefix, stream)); Poll::Pending } } } } impl netbench::Connection for Connection { fn id(&self) -> u64 { self.conn.id() } fn poll_open_bidirectional_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> { self.open_local_stream(id, |conn, cx| conn.poll_open_bidirectional_stream(cx), cx) } fn poll_open_send_stream(&mut self, id: u64, cx: &mut Context) -> Poll<Result<()>> { self.open_local_stream(id, |conn, cx| conn.poll_open_send_stream(cx), cx) } fn poll_accept_stream(&mut self, cx: &mut Context) -> Poll<Result<Option<u64>>> { loop { if let Some((id, stream)) = self.unidentified_peer_stream.as_mut() { let len = ready!(futures::io::AsyncRead::poll_read( Pin::new(stream), cx, id.remaining() ))?; let id = ready!(id.on_read(len)); let (_, stream) = self.unidentified_peer_stream.take().unwrap(); let stream = Stream::new(stream); self.streams[Owner::Remote].insert(id, stream); return Poll::Ready(Ok(Some(id))); } let stream = ready!(self.conn.poll_accept(cx)); if let Ok(Some(stream)) = stream { self.unidentified_peer_stream = Some((Default::default(), stream)); } else { return Poll::Ready(Ok(None)); }; } } fn poll_send( &mut self, owner: Owner, id: u64, bytes: u64, cx: &mut Context, ) -> Poll<Result<u64>> { self.streams[owner] .get_mut(&id) .unwrap() .tx .as_mut() .unwrap() .poll_send(bytes, cx) } fn poll_receive( &mut self, owner: Owner, id: u64, bytes: u64, cx: &mut Context, ) -> Poll<Result<u64>> { self.streams[owner] .get_mut(&id) .unwrap() .rx .as_mut() .unwrap() .poll_receive(bytes, cx) } fn poll_send_finish(&mut self, owner: Owner, id: u64, _cx: &mut Context) -> Poll<Result<()>> { if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) { let stream = entry.get_mut(); if let Some(mut stream) = stream.tx.take() { stream.inner.finish().or_else(stream_error)?; } if stream.rx.is_none() { entry.remove(); } } Poll::Ready(Ok(())) } fn poll_receive_finish( &mut self, owner: Owner, id: u64, _cx: &mut Context, ) -> Poll<Result<()>> { if let Entry::Occupied(mut entry) = self.streams[owner].entry(id) { let stream = entry.get_mut(); if let Some(mut stream) = stream.rx.take() { let _ = stream.inner.stop_sending(0u8.into()); } if stream.tx.is_none() { entry.remove(); } } Poll::Ready(Ok(())) } } macro_rules! chunks { () => { [ Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), ] }; } struct Stream { rx: Option<ReceiveStream>, tx: Option<SendStream>, } impl Stream { fn new(stream: impl SplittableStream) -> Self { let (rx, tx) = stream.split(); let rx = rx.map(ReceiveStream::new); let tx = tx.map(SendStream::new); Self { rx, tx } } } struct ReceiveStream { inner: s2n_quic::stream::ReceiveStream, buffered: u64, is_open: bool, } impl ReceiveStream { fn new(inner: s2n_quic::stream::ReceiveStream) -> Self { Self { inner, buffered: 0, is_open: true, } } fn poll_receive(&mut self, bytes: u64, cx: &mut Context) -> Poll<Result<u64>> { if !self.is_open && self.buffered == 0 { return Ok(0).into(); } while self.buffered <= bytes && self.is_open { let mut chunks = chunks!(); if let Poll::Ready(res) = self.inner.poll_receive_vectored(&mut chunks, cx) { let (count, is_open) = res?; self.is_open &= is_open; for chunk in &chunks[..count] { self.buffered += chunk.len() as u64; } } else { break; } } let received_len = bytes.min(self.buffered); self.buffered -= received_len; if !self.is_open && received_len == 0 { return Ok(0).into(); } if received_len == 0 { Poll::Pending } else { Ok(received_len).into() } } } struct SendStream { inner: s2n_quic::stream::SendStream, data: Data, } impl SendStream { fn new(inner: s2n_quic::stream::SendStream) -> Self { Self { inner, data: Data::new(u64::MAX), } } fn poll_send(&mut self, mut bytes: u64, cx: &mut Context) -> Poll<Result<u64>> { if bytes == 0 { return Ok(0).into(); } let mut len = 0; let mut data = self.data; while bytes > 0 { let mut chunks = chunks!(); let count = data.send(bytes as usize, &mut chunks).unwrap(); let initial_len: u64 = chunks.iter().map(|chunk| chunk.len() as u64).sum(); let count = if let Poll::Ready(count) = self.inner.poll_send_vectored(&mut chunks[..count], cx)? { count } else { break; }; if count == chunks.len() { len += initial_len; bytes -= initial_len; continue; } let remaining_len: u64 = chunks[count..].iter().map(|chunk| chunk.len() as u64).sum(); len += initial_len - remaining_len; break; } if len == 0 { return Poll::Pending; } self.data.seek_forward(len as usize); Poll::Ready(Ok(len)) } }