use std::{ collections::{vec_deque, VecDeque}, fmt, iter::{self, FromIterator as _}, mem, net::{IpAddr, SocketAddr}, }; /// Parse a host into parts (hostname and port). pub trait Address: Unpin + 'static { /// Get hostname part. fn hostname(&self) -> &str; /// Get optional port part. fn port(&self) -> Option { None } } impl Address for String { fn hostname(&self) -> &str { &self } } impl Address for &'static str { fn hostname(&self) -> &str { self } } #[derive(Debug, Eq, PartialEq, Hash)] pub(crate) enum ConnectAddrs { None, One(SocketAddr), Multi(VecDeque), } impl ConnectAddrs { pub(crate) fn is_none(&self) -> bool { matches!(self, Self::None) } pub(crate) fn is_some(&self) -> bool { !self.is_none() } } impl Default for ConnectAddrs { fn default() -> Self { Self::None } } impl From> for ConnectAddrs { fn from(addr: Option) -> Self { match addr { Some(addr) => ConnectAddrs::One(addr), None => ConnectAddrs::None, } } } /// Connection info. #[derive(Debug, PartialEq, Eq, Hash)] pub struct Connect { pub(crate) req: T, pub(crate) port: u16, pub(crate) addr: ConnectAddrs, pub(crate) local_addr: Option, } impl Connect { /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 pub fn new(req: T) -> Connect { let (_, port) = parse_host(req.hostname()); Connect { req, port: port.unwrap_or(0), addr: ConnectAddrs::None, local_addr: None, } } /// Create new `Connect` instance from host and address. Connector skips name resolution stage /// for such connect messages. pub fn with_addr(req: T, addr: SocketAddr) -> Connect { Connect { req, port: 0, addr: ConnectAddrs::One(addr), local_addr: None, } } /// Use port if address does not provide one. /// /// Default value is 0. pub fn set_port(mut self, port: u16) -> Self { self.port = port; self } /// Set address. pub fn set_addr(mut self, addr: Option) -> Self { self.addr = ConnectAddrs::from(addr); self } /// Set list of addresses. pub fn set_addrs(mut self, addrs: I) -> Self where I: IntoIterator, { let mut addrs = VecDeque::from_iter(addrs); self.addr = if addrs.len() < 2 { ConnectAddrs::from(addrs.pop_front()) } else { ConnectAddrs::Multi(addrs) }; self } /// Set local_addr of connect. pub fn set_local_addr(mut self, addr: impl Into) -> Self { self.local_addr = Some(addr.into()); self } /// Get hostname. pub fn hostname(&self) -> &str { self.req.hostname() } /// Get request port. pub fn port(&self) -> u16 { self.req.port().unwrap_or(self.port) } /// Get resolved request addresses. pub fn addrs(&self) -> ConnectAddrsIter<'_> { match self.addr { ConnectAddrs::None => ConnectAddrsIter::None, ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::Multi(ref addrs) => ConnectAddrsIter::Multi(addrs.iter()), } } /// Take resolved request addresses. pub fn take_addrs(&mut self) -> ConnectAddrsIter<'static> { match mem::take(&mut self.addr) { ConnectAddrs::None => ConnectAddrsIter::None, ConnectAddrs::One(addr) => ConnectAddrsIter::One(addr), ConnectAddrs::Multi(addrs) => ConnectAddrsIter::MultiOwned(addrs.into_iter()), } } } impl From for Connect { fn from(addr: T) -> Self { Connect::new(addr) } } impl fmt::Display for Connect { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}:{}", self.hostname(), self.port()) } } /// Iterator over addresses in a [`Connect`] request. #[derive(Clone)] pub enum ConnectAddrsIter<'a> { None, One(SocketAddr), Multi(vec_deque::Iter<'a, SocketAddr>), MultiOwned(vec_deque::IntoIter), } impl Iterator for ConnectAddrsIter<'_> { type Item = SocketAddr; fn next(&mut self) -> Option { match *self { Self::None => None, Self::One(addr) => { *self = Self::None; Some(addr) } Self::Multi(ref mut iter) => iter.next().copied(), Self::MultiOwned(ref mut iter) => iter.next(), } } fn size_hint(&self) -> (usize, Option) { match *self { Self::None => (0, Some(0)), Self::One(_) => (1, Some(1)), Self::Multi(ref iter) => iter.size_hint(), Self::MultiOwned(ref iter) => iter.size_hint(), } } } impl fmt::Debug for ConnectAddrsIter<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_list().entries(self.clone()).finish() } } impl iter::ExactSizeIterator for ConnectAddrsIter<'_> {} impl iter::FusedIterator for ConnectAddrsIter<'_> {} pub struct Connection { io: U, req: T, } impl Connection { pub fn new(io: U, req: T) -> Self { Self { io, req } } } impl Connection { /// Reconstruct from a parts. pub fn from_parts(io: U, req: T) -> Self { Self { io, req } } /// Deconstruct into a parts. pub fn into_parts(self) -> (U, T) { (self.io, self.req) } /// Replace inclosed object, return new Stream and old object pub fn replace_io(self, io: Y) -> (U, Connection) { (self.io, Connection { io, req: self.req }) } /// Returns a shared reference to the underlying stream. pub fn io_ref(&self) -> &U { &self.io } /// Returns a mutable reference to the underlying stream. pub fn io_mut(&mut self) -> &mut U { &mut self.io } } impl Connection { /// Get hostname. pub fn host(&self) -> &str { self.req.hostname() } } impl std::ops::Deref for Connection { type Target = U; fn deref(&self) -> &U { &self.io } } impl std::ops::DerefMut for Connection { fn deref_mut(&mut self) -> &mut U { &mut self.io } } impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Stream {{{:?}}}", self.io) } } fn parse_host(host: &str) -> (&str, Option) { let mut parts_iter = host.splitn(2, ':'); match parts_iter.next() { Some(hostname) => { let port_str = parts_iter.next().unwrap_or(""); let port = port_str.parse::().ok(); (hostname, port) } None => (host, None), } } #[cfg(test)] mod tests { use std::net::Ipv4Addr; use super::*; #[test] fn test_host_parser() { assert_eq!(parse_host("example.com"), ("example.com", None)); assert_eq!(parse_host("example.com:8080"), ("example.com", Some(8080))); assert_eq!(parse_host("example:8080"), ("example", Some(8080))); assert_eq!(parse_host("example.com:false"), ("example.com", None)); assert_eq!(parse_host("example.com:false:false"), ("example.com", None)); } #[test] fn test_addr_iter_multi() { let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); let unspecified = SocketAddr::from((IpAddr::from(Ipv4Addr::UNSPECIFIED), 8080)); let mut addrs = VecDeque::new(); addrs.push_back(localhost); addrs.push_back(unspecified); let mut iter = ConnectAddrsIter::Multi(addrs.iter()); assert_eq!(iter.next(), Some(localhost)); assert_eq!(iter.next(), Some(unspecified)); assert_eq!(iter.next(), None); let mut iter = ConnectAddrsIter::MultiOwned(addrs.into_iter()); assert_eq!(iter.next(), Some(localhost)); assert_eq!(iter.next(), Some(unspecified)); assert_eq!(iter.next(), None); } #[test] fn test_addr_iter_single() { let localhost = SocketAddr::from((IpAddr::from(Ipv4Addr::LOCALHOST), 8080)); let mut iter = ConnectAddrsIter::One(localhost); assert_eq!(iter.next(), Some(localhost)); assert_eq!(iter.next(), None); let mut iter = ConnectAddrsIter::None; assert_eq!(iter.next(), None); } #[test] fn test_local_addr() { let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]); assert_eq!( conn.local_addr.unwrap(), IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) ) } }