quic/s2n-quic-platform/src/io/tokio.rs (244 lines of code) (raw):
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::{features::gso, message::default as message, socket, syscall};
use s2n_quic_core::{
endpoint::Endpoint,
event::{self, EndpointPublisher as _},
inet::{self, SocketAddress},
io::event_loop::EventLoop,
path::{mtu, MaxMtu},
task::cooldown::Cooldown,
time::Clock as ClockTrait,
};
use std::{convert::TryInto, io, io::ErrorKind};
use tokio::runtime::Handle;
mod builder;
mod clock;
pub(crate) mod task;
#[cfg(test)]
mod tests;
pub type PathHandle = message::Handle;
pub use builder::Builder;
pub(crate) use clock::Clock;
#[derive(Debug, Default)]
pub struct Io {
builder: Builder,
}
impl Io {
pub fn builder() -> Builder {
Builder::default()
}
pub fn new<A: std::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
let address = addr.to_socket_addrs()?.next().expect("missing address");
let builder = Builder::default().with_receive_address(address)?;
Ok(Self { builder })
}
pub fn start<E: Endpoint<PathHandle = PathHandle>>(
self,
mut endpoint: E,
) -> io::Result<(tokio::task::JoinHandle<()>, SocketAddress)> {
let Builder {
handle,
rx_socket,
tx_socket,
recv_addr,
send_addr,
socket_recv_buffer_size,
socket_send_buffer_size,
queue_recv_buffer_size,
queue_send_buffer_size,
mtu_config_builder,
max_segments,
gro_enabled,
reuse_address,
reuse_port,
only_v6,
} = self.builder;
let clock = Clock::default();
let mut publisher = event::EndpointPublisherSubscriber::new(
event::builder::EndpointMeta {
endpoint_type: E::ENDPOINT_TYPE,
timestamp: clock.get_time(),
},
None,
endpoint.subscriber(),
);
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::Gso {
max_segments: max_segments.into(),
},
});
// try to use the tokio runtime handle if provided, otherwise try to use the implicit tokio
// runtime in the current scope of the application.
let handle = if let Some(handle) = handle {
handle
} else {
Handle::try_current().map_err(|err| std::io::Error::new(io::ErrorKind::Other, err))?
};
let guard = handle.enter();
let rx_socket = if let Some(rx_socket) = rx_socket {
rx_socket
} else if let Some(recv_addr) = recv_addr {
syscall::bind_udp(recv_addr, reuse_address, reuse_port, only_v6)?
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"missing bind address",
));
};
let rx_addr = convert_addr_to_std(rx_socket.local_addr()?)?;
let tx_socket = if let Some(tx_socket) = tx_socket {
tx_socket
} else if let Some(send_addr) = send_addr {
syscall::bind_udp(send_addr, reuse_address, reuse_port, only_v6)?
} else {
// No tx_socket or send address was specified, so the tx socket
// will be a handle to the rx socket.
rx_socket.try_clone()?
};
if let Some(size) = socket_send_buffer_size {
tx_socket.set_send_buffer_size(size)?;
}
if let Some(size) = socket_recv_buffer_size {
rx_socket.set_recv_buffer_size(size)?;
}
let mut mtu_config = mtu_config_builder
.build()
.map_err(|err| io::Error::new(ErrorKind::InvalidInput, format!("{err}")))?;
let original_max_mtu = mtu_config.max_mtu();
// Configure MTU discovery
if !syscall::configure_mtu_disc(&tx_socket) {
// disable MTU probing if we can't prevent fragmentation
mtu_config = mtu::Config::MIN;
}
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::BaseMtu {
mtu: mtu_config.base_mtu().into(),
},
});
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::InitialMtu {
mtu: mtu_config.initial_mtu().into(),
},
});
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::MaxMtu {
mtu: mtu_config.max_mtu().into(),
},
});
// Configure the socket with GRO
let gro_enabled = gro_enabled.unwrap_or(true) && syscall::configure_gro(&rx_socket);
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::Gro {
enabled: gro_enabled,
},
});
// Configure packet info CMSG
syscall::configure_pktinfo(&rx_socket);
// Configure TOS/ECN
let tos_enabled = syscall::configure_tos(&rx_socket);
publisher.on_platform_feature_configured(event::builder::PlatformFeatureConfigured {
configuration: event::builder::PlatformFeatureConfiguration::Ecn {
enabled: tos_enabled,
},
});
let (stats_sender, stats_recv) = crate::socket::stats::channel();
let rx = {
// if GRO is enabled, then we need to provide the syscall with the maximum size buffer
let payload_len = if gro_enabled {
u16::MAX
} else {
// Use the originally configured MTU to allow larger packets to be received
// even if the tx MTU has been reduced due to configure_mtu_disc failing
original_max_mtu.into()
} as u32;
let rx_buffer_size = queue_recv_buffer_size.unwrap_or(8 * (1 << 20));
let entries = rx_buffer_size / payload_len;
let entries = if entries.is_power_of_two() {
entries
} else {
// round up to the nearest power of two, since the ring buffers require it
entries.next_power_of_two()
};
let mut consumers = vec![];
let rx_socket_count = parse_env("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT").unwrap_or(1);
// configure the number of self-wakes before "cooling down" and waiting for epoll to
// complete
let rx_cooldown = cooldown("RX");
for idx in 0usize..rx_socket_count {
let (producer, consumer) = socket::ring::pair(entries, payload_len);
consumers.push(consumer);
// spawn a task that actually reads from the socket into the ring buffer
if idx + 1 == rx_socket_count {
handle.spawn(task::rx(
rx_socket,
producer,
rx_cooldown,
stats_sender.clone(),
));
break;
} else {
let rx_socket = rx_socket.try_clone()?;
handle.spawn(task::rx(
rx_socket,
producer,
rx_cooldown.clone(),
stats_sender.clone(),
));
}
}
// construct the RX side for the endpoint event loop
let max_mtu = MaxMtu::try_from(payload_len as u16).unwrap();
let addr: inet::SocketAddress = rx_addr.into();
socket::io::rx::Rx::new(consumers, max_mtu, addr.into())
};
let tx = {
let gso = crate::features::Gso::from(max_segments);
// compute the payload size for each message from the number of GSO segments we can
// fill
let payload_len = {
let max_mtu: u16 = mtu_config.max_mtu().into();
(max_mtu as u32 * gso.max_segments() as u32).min(u16::MAX as u32)
};
let tx_buffer_size = queue_send_buffer_size.unwrap_or(128 * 1024);
let entries = tx_buffer_size / payload_len;
let entries = if entries.is_power_of_two() {
entries
} else {
// round up to the nearest power of two, since the ring buffers require it
entries.next_power_of_two()
};
let mut producers = vec![];
let tx_socket_count = parse_env("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT").unwrap_or(1);
// configure the number of self-wakes before "cooling down" and waiting for epoll to
// complete
let tx_cooldown = cooldown("TX");
for idx in 0usize..tx_socket_count {
let (producer, consumer) = socket::ring::pair(entries, payload_len);
producers.push(producer);
// spawn a task that actually flushes the ring buffer to the socket
if idx + 1 == tx_socket_count {
handle.spawn(task::tx(
tx_socket,
consumer,
gso.clone(),
tx_cooldown,
stats_sender.clone(),
));
break;
} else {
let tx_socket = tx_socket.try_clone()?;
handle.spawn(task::tx(
tx_socket,
consumer,
gso.clone(),
tx_cooldown.clone(),
stats_sender.clone(),
));
}
}
// construct the TX side for the endpoint event loop
socket::io::tx::Tx::new(producers, gso, mtu_config.max_mtu())
};
// Notify the endpoint of the MTU that we chose
endpoint.set_mtu_config(mtu_config);
let task = handle.spawn(
EventLoop {
endpoint,
clock,
rx,
tx,
cooldown: cooldown("ENDPOINT"),
stats: stats_recv,
}
.start(rx_addr.into()),
);
drop(guard);
Ok((task, rx_addr.into()))
}
}
fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAddr> {
addr.as_socket()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid domain for socket"))
}
fn parse_env<T: core::str::FromStr>(name: &str) -> Option<T> {
std::env::var(name).ok().and_then(|v| v.parse().ok())
}
pub fn cooldown(direction: &str) -> Cooldown {
let name = format!("S2N_QUIC_UNSTABLE_COOLDOWN_{direction}");
let limit = parse_env(&name).unwrap_or(0);
Cooldown::new(limit)
}