1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-11-24 00:01:11 +01:00

add local_addr binding to connector service

This commit is contained in:
fakeshadow 2021-02-24 01:13:17 +08:00
parent 8d74cf387d
commit c3be839a69
4 changed files with 98 additions and 12 deletions

View File

@ -76,7 +76,7 @@ pub mod net {
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
pub use tokio::net::UdpSocket; pub use tokio::net::UdpSocket;
pub use tokio::net::{TcpListener, TcpStream}; pub use tokio::net::{TcpListener, TcpSocket, TcpStream};
#[cfg(unix)] #[cfg(unix)]
pub use tokio::net::{UnixDatagram, UnixListener, UnixStream}; pub use tokio::net::{UnixDatagram, UnixListener, UnixStream};

View File

@ -3,7 +3,7 @@ use std::{
fmt, fmt,
iter::{self, FromIterator as _}, iter::{self, FromIterator as _},
mem, mem,
net::SocketAddr, net::{IpAddr, SocketAddr},
}; };
/// Parse a host into parts (hostname and port). /// Parse a host into parts (hostname and port).
@ -67,6 +67,7 @@ pub struct Connect<T> {
pub(crate) req: T, pub(crate) req: T,
pub(crate) port: u16, pub(crate) port: u16,
pub(crate) addr: ConnectAddrs, pub(crate) addr: ConnectAddrs,
pub(crate) local_addr: Option<IpAddr>,
} }
impl<T: Address> Connect<T> { impl<T: Address> Connect<T> {
@ -78,6 +79,7 @@ impl<T: Address> Connect<T> {
req, req,
port: port.unwrap_or(0), port: port.unwrap_or(0),
addr: ConnectAddrs::None, addr: ConnectAddrs::None,
local_addr: None,
} }
} }
@ -88,6 +90,7 @@ impl<T: Address> Connect<T> {
req, req,
port: 0, port: 0,
addr: ConnectAddrs::One(addr), addr: ConnectAddrs::One(addr),
local_addr: None,
} }
} }
@ -119,6 +122,12 @@ impl<T: Address> Connect<T> {
self self
} }
/// Set local_addr of connect.
pub fn set_local_addr(mut self, addr: impl Into<IpAddr>) -> Self {
self.local_addr = Some(addr.into());
self
}
/// Get hostname. /// Get hostname.
pub fn hostname(&self) -> &str { pub fn hostname(&self) -> &str {
self.req.hostname() self.req.hostname()
@ -285,7 +294,7 @@ fn parse_host(host: &str) -> (&str, Option<u16>) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::net::{IpAddr, Ipv4Addr}; use std::net::Ipv4Addr;
use super::*; use super::*;
@ -329,4 +338,13 @@ mod tests {
let mut iter = ConnectAddrsIter::None; let mut iter = ConnectAddrsIter::None;
assert_eq!(iter.next(), None); assert_eq!(iter.next(), None);
} }
#[test]
fn test_local_addr() {
let conn = Connect::new("hello").set_local_addr([127, 0, 0, 1]);
assert_eq!(
conn.local_addr.unwrap(),
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
)
}
} }

View File

@ -2,12 +2,12 @@ use std::{
collections::VecDeque, collections::VecDeque,
future::Future, future::Future,
io, io,
net::SocketAddr, net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use actix_rt::net::TcpStream; use actix_rt::net::{TcpSocket, TcpStream};
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::{future::LocalBoxFuture, ready}; use futures_core::{future::LocalBoxFuture, ready};
use log::{error, trace}; use log::{error, trace};
@ -54,9 +54,14 @@ impl<T: Address> Service<Connect<T>> for TcpConnector {
fn call(&self, req: Connect<T>) -> Self::Future { fn call(&self, req: Connect<T>) -> Self::Future {
let port = req.port(); let port = req.port();
let Connect { req, addr, .. } = req; let Connect {
req,
addr,
local_addr,
..
} = req;
TcpConnectorResponse::new(req, port, addr) TcpConnectorResponse::new(req, port, local_addr, addr)
} }
} }
@ -65,6 +70,7 @@ pub enum TcpConnectorResponse<T> {
Response { Response {
req: Option<T>, req: Option<T>,
port: u16, port: u16,
local_addr: Option<IpAddr>,
addrs: Option<VecDeque<SocketAddr>>, addrs: Option<VecDeque<SocketAddr>>,
stream: Option<ReusableBoxFuture<Result<TcpStream, io::Error>>>, stream: Option<ReusableBoxFuture<Result<TcpStream, io::Error>>>,
}, },
@ -72,7 +78,12 @@ pub enum TcpConnectorResponse<T> {
} }
impl<T: Address> TcpConnectorResponse<T> { impl<T: Address> TcpConnectorResponse<T> {
pub(crate) fn new(req: T, port: u16, addr: ConnectAddrs) -> TcpConnectorResponse<T> { pub(crate) fn new(
req: T,
port: u16,
local_addr: Option<IpAddr>,
addr: ConnectAddrs,
) -> TcpConnectorResponse<T> {
if addr.is_none() { if addr.is_none() {
error!("TCP connector: unresolved connection address"); error!("TCP connector: unresolved connection address");
return TcpConnectorResponse::Error(Some(ConnectError::Unresolved)); return TcpConnectorResponse::Error(Some(ConnectError::Unresolved));
@ -90,8 +101,9 @@ impl<T: Address> TcpConnectorResponse<T> {
ConnectAddrs::One(addr) => TcpConnectorResponse::Response { ConnectAddrs::One(addr) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
local_addr,
addrs: None, addrs: None,
stream: Some(ReusableBoxFuture::new(TcpStream::connect(addr))), stream: Some(ReusableBoxFuture::new(connect(addr, local_addr))),
}, },
// when resolver returns multiple socket addr for request they would be popped from // when resolver returns multiple socket addr for request they would be popped from
@ -99,6 +111,7 @@ impl<T: Address> TcpConnectorResponse<T> {
ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response { ConnectAddrs::Multi(addrs) => TcpConnectorResponse::Response {
req: Some(req), req: Some(req),
port, port,
local_addr,
addrs: Some(addrs), addrs: Some(addrs),
stream: None, stream: None,
}, },
@ -116,6 +129,7 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
TcpConnectorResponse::Response { TcpConnectorResponse::Response {
req, req,
port, port,
local_addr,
addrs, addrs,
stream, stream,
} => loop { } => loop {
@ -148,11 +162,40 @@ impl<T: Address> Future for TcpConnectorResponse<T> {
// try to connect // try to connect
let addr = addrs.as_mut().unwrap().pop_front().unwrap(); let addr = addrs.as_mut().unwrap().pop_front().unwrap();
let fut = connect(addr, *local_addr);
match stream { match stream {
Some(rbf) => rbf.set(TcpStream::connect(addr)), Some(rbf) => rbf.set(fut),
None => *stream = Some(ReusableBoxFuture::new(TcpStream::connect(addr))), None => *stream = Some(ReusableBoxFuture::new(fut)),
} }
}, },
} }
} }
} }
async fn connect(addr: SocketAddr, local_addr: Option<IpAddr>) -> io::Result<TcpStream> {
// use local addr if connect asks for it.
match local_addr {
Some(ip_addr) => {
let socket = match ip_addr {
IpAddr::V4(ip_addr) => {
let socket = TcpSocket::new_v4()?;
let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0));
socket.bind(addr)?;
socket
}
IpAddr::V6(ip_addr) => {
let socket = TcpSocket::new_v6()?;
let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0));
socket.bind(addr)?;
socket
}
};
socket.set_reuseaddr(true)?;
socket.connect(addr).await
}
None => TcpStream::connect(addr).await,
}
}

View File

@ -1,4 +1,7 @@
use std::io; use std::{
io,
net::{IpAddr, Ipv4Addr},
};
use actix_codec::{BytesCodec, Framed}; use actix_codec::{BytesCodec, Framed};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
@ -125,3 +128,25 @@ async fn test_rustls_uri() {
let con = conn.call(addr.into()).await.unwrap(); let con = conn.call(addr.into()).await.unwrap();
assert_eq!(con.peer_addr().unwrap(), srv.addr()); assert_eq!(con.peer_addr().unwrap(), srv.addr());
} }
#[actix_rt::test]
async fn test_local_addr() {
let srv = TestServer::with(|| {
fn_service(|io: TcpStream| async {
let mut framed = Framed::new(io, BytesCodec);
framed.send(Bytes::from_static(b"test")).await?;
Ok::<_, io::Error>(())
})
});
let conn = actix_connect::default_connector();
let local = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));
let (con, _) = conn
.call(Connect::with_addr("10", srv.addr()).set_local_addr(local))
.await
.unwrap()
.into_parts();
assert_eq!(con.local_addr().unwrap().ip(), local)
}