diff --git a/actix-rt/src/lib.rs b/actix-rt/src/lib.rs index a529bdb0..e21cd651 100644 --- a/actix-rt/src/lib.rs +++ b/actix-rt/src/lib.rs @@ -76,7 +76,7 @@ pub mod net { use tokio::io::{AsyncRead, AsyncWrite}; pub use tokio::net::UdpSocket; - pub use tokio::net::{TcpListener, TcpStream}; + pub use tokio::net::{TcpListener, TcpSocket, TcpStream}; #[cfg(unix)] pub use tokio::net::{UnixDatagram, UnixListener, UnixStream}; diff --git a/actix-tls/src/connect/connect.rs b/actix-tls/src/connect/connect.rs index 9e5d417f..bd4b3fdf 100755 --- a/actix-tls/src/connect/connect.rs +++ b/actix-tls/src/connect/connect.rs @@ -3,7 +3,7 @@ use std::{ fmt, iter::{self, FromIterator as _}, mem, - net::SocketAddr, + net::{IpAddr, SocketAddr}, }; /// Parse a host into parts (hostname and port). @@ -67,6 +67,7 @@ pub struct Connect { pub(crate) req: T, pub(crate) port: u16, pub(crate) addr: ConnectAddrs, + pub(crate) local_addr: Option, } impl Connect { @@ -78,6 +79,7 @@ impl Connect { req, port: port.unwrap_or(0), addr: ConnectAddrs::None, + local_addr: None, } } @@ -88,6 +90,7 @@ impl Connect { req, port: 0, addr: ConnectAddrs::One(addr), + local_addr: None, } } @@ -119,6 +122,12 @@ impl Connect { self } + /// Set local_addr of connect. + pub fn set_local_addr(mut self, addr: impl Into) -> Self { + self.local_addr = Some(addr.into()); + self + } + /// Get hostname. pub fn hostname(&self) -> &str { self.req.hostname() @@ -285,7 +294,7 @@ fn parse_host(host: &str) -> (&str, Option) { #[cfg(test)] mod tests { - use std::net::{IpAddr, Ipv4Addr}; + use std::net::Ipv4Addr; use super::*; @@ -329,4 +338,13 @@ mod tests { let mut iter = ConnectAddrsIter::None; assert_eq!(iter.next(), None); } + + #[test] + fn test_local_addr() { + let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]); + assert_eq!( + conn.local_addr.unwrap(), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) + ) + } } diff --git a/actix-tls/src/connect/connector.rs b/actix-tls/src/connect/connector.rs index 9acb1dd5..4d98ba78 100755 --- a/actix-tls/src/connect/connector.rs +++ b/actix-tls/src/connect/connector.rs @@ -2,12 +2,12 @@ use std::{ collections::VecDeque, future::Future, io, - net::SocketAddr, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}, pin::Pin, task::{Context, Poll}, }; -use actix_rt::net::TcpStream; +use actix_rt::net::{TcpSocket, TcpStream}; use actix_service::{Service, ServiceFactory}; use futures_core::{future::LocalBoxFuture, ready}; use log::{error, trace}; @@ -54,9 +54,14 @@ impl Service> for TcpConnector { fn call(&self, req: Connect) -> Self::Future { let port = req.port(); - let Connect { req, addr, .. } = req; + let Connect { + req, + addr, + local_addr, + .. + } = req; - TcpConnectorResponse::new(req, port, addr) + TcpConnectorResponse::new(req, port, local_addr, addr) } } @@ -65,6 +70,7 @@ pub enum TcpConnectorResponse { Response { req: Option, port: u16, + local_addr: Option, addrs: Option>, stream: Option>>, }, @@ -72,7 +78,12 @@ pub enum TcpConnectorResponse { } impl TcpConnectorResponse { - pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse { + pub(crate) fn new( + req: T, + port: u16, + local_addr: Option, + addr: ConnectAddrs, + ) -> TcpConnectorResponse { if addr.is_none() { error!("TCP connector: unresolved connection address"); return TcpConnectorResponse::Error(Some(ConnectError::Unresolved)); @@ -90,8 +101,9 @@ impl TcpConnectorResponse { ConnectAddrs::One(addr) => TcpConnectorResponse::Response { req: Some(req), port, + local_addr, addrs: None, - stream: Some(ReusableBoxFuture::new(TcpStream::connect(addr))), + stream: Some(ReusableBoxFuture::new(connect(addr, local_addr))), }, // when resolver returns multiple socket addr for request they would be popped from @@ -99,6 +111,7 @@ impl TcpConnectorResponse { ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { req: Some(req), port, + local_addr, addrs: Some(addrs), stream: None, }, @@ -116,6 +129,7 @@ impl Future for TcpConnectorResponse { TcpConnectorResponse::Response { req, port, + local_addr, addrs, stream, } => loop { @@ -148,11 +162,40 @@ impl Future for TcpConnectorResponse { // try to connect let addr = addrs.as_mut().unwrap().pop_front().unwrap(); + let fut = connect(addr, *local_addr); match stream { - Some(rbf) => rbf.set(TcpStream::connect(addr)), - None => *stream = Some(ReusableBoxFuture::new(TcpStream::connect(addr))), + Some(rbf) => rbf.set(fut), + None => *stream = Some(ReusableBoxFuture::new(fut)), } }, } } } + +async fn connect(addr: SocketAddr, local_addr: Option) -> io::Result { + // use local addr if connect asks for it. + match local_addr { + Some(ip_addr) => { + let socket = match ip_addr { + IpAddr::V4(ip_addr) => { + let socket = TcpSocket::new_v4()?; + let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0)); + socket.bind(addr)?; + socket + } + IpAddr::V6(ip_addr) => { + let socket = TcpSocket::new_v6()?; + let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0)); + socket.bind(addr)?; + socket + } + }; + + socket.set_reuseaddr(true)?; + + socket.connect(addr).await + } + + None => TcpStream::connect(addr).await, + } +} diff --git a/actix-tls/tests/test_connect.rs b/actix-tls/tests/test_connect.rs index 7ee7afda..e8e23757 100755 --- a/actix-tls/tests/test_connect.rs +++ b/actix-tls/tests/test_connect.rs @@ -1,4 +1,7 @@ -use std::io; +use std::{ + io, + net::{IpAddr, Ipv4Addr}, +}; use actix_codec::{BytesCodec, Framed}; use actix_rt::net::TcpStream; @@ -125,3 +128,25 @@ async fn test_rustls_uri() { let con = conn.call(addr.into()).await.unwrap(); assert_eq!(con.peer_addr().unwrap(), srv.addr()); } + +#[actix_rt::test] +async fn test_local_addr() { + let srv = TestServer::with(|| { + fn_service(|io: TcpStream| async { + let mut framed = Framed::new(io, BytesCodec); + framed.send(Bytes::from_static(b"test")).await?; + Ok::<_, io::Error>(()) + }) + }); + + let conn = actix_connect::default_connector(); + let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3)); + + let (con, _) = conn + .call(Connect::with_addr("10", srv.addr()).set_local_addr(local)) + .await + .unwrap() + .into_parts(); + + assert_eq!(con.local_addr().unwrap().ip(), local) +}