src/net.rs (48 lines of code) (raw):

// Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! The module is used to provide abstraction over TCP socket and UDS. use std::fmt; #[cfg(any(target_os = "linux", target_os = "android"))] use std::os::linux::net::SocketAddrExt; use futures::{Future, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite}; // A unify version of `std::net::SocketAddr` and Unix domain socket. #[derive(Debug)] pub enum SocketAddr { Net(std::net::SocketAddr), // This could work on Windows in the future. See also rust-lang/rust#56533. #[cfg(unix)] Unix(std::path::PathBuf), #[cfg(any(target_os = "linux", target_os = "android"))] UnixAbstract(Vec<u8>), } impl fmt::Display for SocketAddr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { SocketAddr::Net(addr) => write!(f, "{}", addr), #[cfg(unix)] SocketAddr::Unix(p) => write!(f, "{}", p.display()), #[cfg(any(target_os = "linux", target_os = "android"))] SocketAddr::UnixAbstract(p) => write!(f, "\\x00{}", p.escape_ascii()), } } } impl SocketAddr { /// Get a Net address that with IP part set to "127.0.0.1". #[inline] pub fn with_port(port: u16) -> Self { SocketAddr::Net(std::net::SocketAddr::from(([127, 0, 0, 1], port))) } #[inline] pub fn as_net(&self) -> Option<&std::net::SocketAddr> { match self { SocketAddr::Net(addr) => Some(addr), #[cfg(unix)] _ => None, } } /// Parse a string as a unix domain socket. /// /// The string should follow the format of `self.to_string()`. #[cfg(unix)] pub fn parse_uds(s: &str) -> std::io::Result<Self> { // Parse abstract socket address first as it can contain any chars. #[cfg(any(target_os = "linux", target_os = "android"))] { if s.starts_with("\\x00") { // Rust abstract path expects no prepand '\x00'. let data = crate::util::ascii_unescape_default(&s.as_bytes()[4..])?; return Ok(SocketAddr::UnixAbstract(data)); } } let path = std::path::PathBuf::from(s); Ok(SocketAddr::Unix(path)) } #[cfg(unix)] pub fn is_unix_path(&self) -> bool { matches!(self, SocketAddr::Unix(_)) } #[cfg(not(unix))] pub fn is_unix_path(&self) -> bool { false } } // A helper trait to unify the behavior of TCP and UDS listener. pub trait Acceptor { type Socket: AsyncRead + AsyncWrite + Unpin + Send; fn accept(&self) -> impl Future<Output = tokio::io::Result<Self::Socket>> + Send; fn local_addr(&self) -> tokio::io::Result<Option<SocketAddr>>; } impl Acceptor for tokio::net::TcpListener { type Socket = tokio::net::TcpStream; #[inline] fn accept(&self) -> impl Future<Output = tokio::io::Result<Self::Socket>> + Send { tokio::net::TcpListener::accept(self).and_then(|(s, _)| futures::future::ok(s)) } #[inline] fn local_addr(&self) -> tokio::io::Result<Option<SocketAddr>> { tokio::net::TcpListener::local_addr(self).map(|a| Some(SocketAddr::Net(a))) } } // A helper trait to unify the behavior of TCP and UDS stream. pub trait Connection: std::io::Read + std::io::Write { fn try_clone(&self) -> std::io::Result<Box<dyn Connection>>; } impl Connection for std::net::TcpStream { #[inline] fn try_clone(&self) -> std::io::Result<Box<dyn Connection>> { let stream = std::net::TcpStream::try_clone(self)?; Ok(Box::new(stream)) } } // Helper function to create a stream. Uses dynamic dispatch to make code more // readable. pub fn connect(addr: &SocketAddr) -> std::io::Result<Box<dyn Connection>> { match addr { SocketAddr::Net(addr) => { std::net::TcpStream::connect(addr).map(|s| Box::new(s) as Box<dyn Connection>) } #[cfg(unix)] SocketAddr::Unix(p) => { std::os::unix::net::UnixStream::connect(p).map(|s| Box::new(s) as Box<dyn Connection>) } #[cfg(any(target_os = "linux", target_os = "android"))] SocketAddr::UnixAbstract(p) => { let sock = std::os::unix::net::SocketAddr::from_abstract_name(p)?; std::os::unix::net::UnixStream::connect_addr(&sock) .map(|s| Box::new(s) as Box<dyn Connection>) } } } #[cfg(unix)] mod unix_imp { use futures::TryFutureExt; use super::*; impl Acceptor for tokio::net::UnixListener { type Socket = tokio::net::UnixStream; #[inline] fn accept(&self) -> impl Future<Output = tokio::io::Result<Self::Socket>> + Send { tokio::net::UnixListener::accept(self).and_then(|(s, _)| futures::future::ok(s)) } #[inline] fn local_addr(&self) -> tokio::io::Result<Option<SocketAddr>> { let addr = tokio::net::UnixListener::local_addr(self)?; if let Some(p) = addr.as_pathname() { return Ok(Some(SocketAddr::Unix(p.to_path_buf()))); } // TODO: support get addr from abstract socket. // tokio::net::SocketAddr needs to support `as_abstract_name`. // #[cfg(any(target_os = "linux", target_os = "android"))] // if let Some(p) = addr.0.as_abstract_name() { // return Ok(SocketAddr::UnixAbstract(p.to_vec())); // } Ok(None) } } impl Connection for std::os::unix::net::UnixStream { #[inline] fn try_clone(&self) -> std::io::Result<Box<dyn Connection>> { let stream = std::os::unix::net::UnixStream::try_clone(self)?; Ok(Box::new(stream)) } } }