From a02e0dfab60041061cb1a7c7d6fb08e18d85f08e Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 29 Jan 2018 23:01:20 -0800 Subject: [PATCH] initial work on client connector --- Cargo.toml | 2 +- examples/websocket/src/client.rs | 3 +- src/client/connect.rs | 179 +++++++++++++++++++++++++++++++ src/client/mod.rs | 1 + src/lib.rs | 11 ++ src/ws/client.rs | 86 ++++++++------- src/ws/connect.rs | 139 ------------------------ src/ws/mod.rs | 2 - 8 files changed, 238 insertions(+), 185 deletions(-) create mode 100644 src/client/connect.rs delete mode 100644 src/ws/connect.rs diff --git a/Cargo.toml b/Cargo.toml index cdb3bd781..1540fb938 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "0.3.4" +version = "0.4.0" authors = ["Nikolay Kim "] description = "Actix web framework" readme = "README.md" diff --git a/examples/websocket/src/client.rs b/examples/websocket/src/client.rs index 2b114263f..bd44c70b3 100644 --- a/examples/websocket/src/client.rs +++ b/examples/websocket/src/client.rs @@ -12,7 +12,6 @@ use std::time::Duration; use actix::*; use futures::Future; -use tokio_core::net::TcpStream; use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; @@ -54,7 +53,7 @@ fn main() { } -struct ChatClient(WsClientWriter); +struct ChatClient(WsClientWriter); #[derive(Message)] struct ClientCommand(String); diff --git a/src/client/connect.rs b/src/client/connect.rs new file mode 100644 index 000000000..1e2432b01 --- /dev/null +++ b/src/client/connect.rs @@ -0,0 +1,179 @@ +#![allow(unused_imports, dead_code)] +use std::{io, time}; +use std::net::{SocketAddr, Shutdown}; +use std::collections::VecDeque; +use std::time::Duration; + +use actix::{fut, Actor, ActorFuture, Arbiter, ArbiterService, Context, + Handler, Response, ResponseType, Supervised}; +use actix::actors::{Connector, ConnectorError, Connect as ResolveConnect}; + +use http::Uri; +use futures::{Async, Future, Poll}; +use tokio_core::reactor::Timeout; +use tokio_core::net::{TcpStream, TcpStreamNew}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use server::IoStream; + +#[derive(Debug)] +pub struct Connect(pub Uri); + +impl ResponseType for Connect { + type Item = Connection; + type Error = ClientConnectorError; +} + +#[derive(Fail, Debug)] +pub enum ClientConnectorError { + /// Invalid url + #[fail(display="Invalid url")] + InvalidUrl, + + /// SSL feature is not enabled + #[fail(display="SSL is not supported")] + SslIsNotSupported, + + /// Connection error + #[fail(display = "{}", _0)] + Connector(ConnectorError), + + /// Connecting took too long + #[fail(display = "Timeout out while establishing connection")] + Timeout, + + /// Connector has been disconnected + #[fail(display = "Internal error: connector has been disconnected")] + Disconnected, + + /// Connection io error + #[fail(display = "{}", _0)] + IoError(io::Error), +} + +impl From for ClientConnectorError { + fn from(err: ConnectorError) -> ClientConnectorError { + ClientConnectorError::Connector(err) + } +} + +#[derive(Debug, Default)] +pub struct ClientConnector { +} + +impl Actor for ClientConnector { + type Context = Context; +} + +impl Supervised for ClientConnector {} + +impl ArbiterService for ClientConnector {} + +impl Handler for ClientConnector { + type Result = Response; + + fn handle(&mut self, msg: Connect, _: &mut Self::Context) -> Self::Result { + let uri = &msg.0; + + if uri.host().is_none() { + return Self::reply(Err(ClientConnectorError::InvalidUrl)) + } + + let proto = match uri.scheme_part() { + Some(scheme) => match Protocol::from(scheme.as_str()) { + Some(proto) => proto, + None => return Self::reply(Err(ClientConnectorError::InvalidUrl)), + }, + None => return Self::reply(Err(ClientConnectorError::InvalidUrl)), + }; + + let port = uri.port().unwrap_or_else(|| proto.port()); + + Self::async_reply( + Connector::from_registry() + .call(self, ResolveConnect::host_and_port(uri.host().unwrap(), port)) + .map_err(|_, _, _| ClientConnectorError::Disconnected) + .and_then(|res, _, _| match res { + Ok(stream) => fut::ok(Connection{stream: Box::new(stream)}), + Err(err) => fut::err(err.into()) + })) + } +} + +#[derive(PartialEq, Hash, Debug)] +enum Protocol { + Http, + Https, + Ws, + Wss, +} + +impl Protocol { + fn from(s: &str) -> Option { + match s { + "http" => Some(Protocol::Http), + "https" => Some(Protocol::Https), + "ws" => Some(Protocol::Ws), + "wss" => Some(Protocol::Wss), + _ => None, + } + } + + fn port(&self) -> u16 { + match *self { + Protocol::Http | Protocol::Ws => 80, + Protocol::Https | Protocol::Wss => 443 + } + } +} + + +pub struct Connection { + stream: Box, +} + +impl Connection { + pub fn stream(&mut self) -> &mut IoStream { + &mut *self.stream + } +} + +impl IoStream for Connection { + fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + IoStream::shutdown(&mut *self.stream, how) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + IoStream::set_nodelay(&mut *self.stream, nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + IoStream::set_linger(&mut *self.stream, dur) + } +} + +impl io::Read for Connection { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.stream.read(buf) + } +} + +impl AsyncRead for Connection {} + +impl io::Write for Connection { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} + +impl AsyncWrite for Connection { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.stream.shutdown() + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index de9bfdccb..3bd96f642 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod connect; mod parser; mod request; mod response; diff --git a/src/lib.rs b/src/lib.rs index 4722f8c44..8194793f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,6 +147,17 @@ pub use native_tls::Pkcs12; #[cfg(feature="openssl")] pub use openssl::pkcs12::Pkcs12; +#[cfg(feature="openssl")] +pub(crate) const HAS_OPENSSL: bool = true; +#[cfg(not(feature="openssl"))] +pub(crate) const HAS_OPENSSL: bool = false; + +#[cfg(feature="tls")] +pub(crate) const HAS_TLS: bool = true; +#[cfg(not(feature="tls"))] +pub(crate) const HAS_TLS: bool = false; + + pub mod headers { //! Headers implementation diff --git a/src/ws/client.rs b/src/ws/client.rs index b1a75fa89..83c912f4d 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -1,4 +1,5 @@ //! Http client request +#![allow(unused_imports, dead_code)] use std::{io, str}; use std::rc::Rc; use std::time::Duration; @@ -12,9 +13,11 @@ use http::{HttpTryFrom, StatusCode, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use sha1::Sha1; use futures::{Async, Future, Poll, Stream}; -// use futures::unsync::oneshot; +use futures::future::{Either, err as FutErr}; use tokio_core::net::TcpStream; +use actix::prelude::*; + use body::Binary; use error::UrlParseError; use server::shared::SharedBytes; @@ -22,14 +25,14 @@ use server::shared::SharedBytes; use server::{utils, IoStream}; use client::{ClientRequest, ClientRequestBuilder, HttpResponseParser, HttpResponseParserError, HttpClientWriter}; +use client::connect::{Connect, Connection, ClientConnector, ClientConnectorError}; use super::Message; use super::proto::{CloseCode, OpCode}; use super::frame::Frame; -use super::connect::{TcpConnector, TcpConnectorError}; -pub type WsClientFuture = - Future, WsClientWriter), Error=WsClientError>; +pub type WsClientFuture = + Future; /// Websockt client error @@ -52,7 +55,7 @@ pub enum WsClientError { #[fail(display="Response parsing error")] ResponseParseError(HttpResponseParserError), #[fail(display="{}", _0)] - Connection(TcpConnectorError), + Connector(ClientConnectorError), #[fail(display="{}", _0)] Io(io::Error), #[fail(display="Disconnected")] @@ -71,9 +74,9 @@ impl From for WsClientError { } } -impl From for WsClientError { - fn from(err: TcpConnectorError) -> WsClientError { - WsClientError::Connection(err) +impl From for WsClientError { + fn from(err: ClientConnectorError) -> WsClientError { + WsClientError::Connector(err) } } @@ -145,7 +148,7 @@ impl WsClient { self } - pub fn connect(&mut self) -> Result>, WsClientError> { + pub fn connect(&mut self) -> Result, WsClientError> { if let Some(e) = self.err.take() { return Err(e) } @@ -178,19 +181,20 @@ impl WsClient { return Err(WsClientError::InvalidUrl); } - let connect = TcpConnector::new( - request.uri().host().unwrap(), - request.uri().port().unwrap_or(80), Duration::from_secs(5)); - + // get connection and start handshake Ok(Box::new( - connect - .from_err() - .and_then(move |stream| WsHandshake::new(stream, request)))) + ClientConnector::from_registry().call_fut(Connect(request.uri().clone())) + .map_err(|_| WsClientError::Disconnected) + .and_then(|res| match res { + Ok(stream) => Either::A(WsHandshake::new(stream, request)), + Err(err) => Either::B(FutErr(err.into())), + }) + )) } } -struct WsInner { - stream: T, +struct WsInner { + conn: Connection, writer: HttpClientWriter, parser: HttpResponseParser, parser_buf: BytesMut, @@ -198,15 +202,15 @@ struct WsInner { error_sent: bool, } -struct WsHandshake { - inner: Option>, +struct WsHandshake { + inner: Option, request: ClientRequest, sent: bool, key: String, } -impl WsHandshake { - fn new(stream: T, mut request: ClientRequest) -> WsHandshake { +impl WsHandshake { + fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake { // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, // when decoded, is 16 bytes in length (RFC 6455) @@ -218,7 +222,7 @@ impl WsHandshake { HeaderValue::try_from(key.as_str()).unwrap()); let inner = WsInner { - stream: stream, + conn: conn, writer: HttpClientWriter::new(SharedBytes::default()), parser: HttpResponseParser::default(), parser_buf: BytesMut::new(), @@ -235,8 +239,8 @@ impl WsHandshake { } } -impl Future for WsHandshake { - type Item = (WsClientReader, WsClientWriter); +impl Future for WsHandshake { + type Item = (WsClientReader, WsClientWriter); type Error = WsClientError; fn poll(&mut self) -> Poll { @@ -246,11 +250,11 @@ impl Future for WsHandshake { self.sent = true; inner.writer.start(&mut self.request); } - if let Err(err) = inner.writer.poll_completed(&mut inner.stream, false) { + if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { return Err(err.into()) } - match inner.parser.parse(&mut inner.stream, &mut inner.parser_buf) { + match inner.parser.parse(&mut inner.conn, &mut inner.parser_buf) { Ok(Async::Ready(resp)) => { // verify response if resp.status() != StatusCode::SWITCHING_PROTOCOLS { @@ -311,22 +315,22 @@ impl Future for WsHandshake { } -struct Inner { - inner: WsInner, +struct Inner { + inner: WsInner, } -pub struct WsClientReader { - inner: Rc>> +pub struct WsClientReader { + inner: Rc> } -impl WsClientReader { +impl WsClientReader { #[inline] - fn as_mut(&mut self) -> &mut Inner { + fn as_mut(&mut self) -> &mut Inner { unsafe{ &mut *self.inner.get() } } } -impl Stream for WsClientReader { +impl Stream for WsClientReader { type Item = Message; type Error = WsClientError; @@ -334,7 +338,7 @@ impl Stream for WsClientReader { let inner = self.as_mut(); let mut done = false; - match utils::read_from_io(&mut inner.inner.stream, &mut inner.inner.parser_buf) { + match utils::read_from_io(&mut inner.inner.conn, &mut inner.inner.parser_buf) { Ok(Async::Ready(0)) => { done = true; inner.inner.closed = true; @@ -345,7 +349,7 @@ impl Stream for WsClientReader { } // write - let _ = inner.inner.writer.poll_completed(&mut inner.inner.stream, false); + let _ = inner.inner.writer.poll_completed(&mut inner.inner.conn, false); // read match Frame::parse(&mut inner.inner.parser_buf) { @@ -406,18 +410,18 @@ impl Stream for WsClientReader { } } -pub struct WsClientWriter { - inner: Rc>> +pub struct WsClientWriter { + inner: Rc> } -impl WsClientWriter { +impl WsClientWriter { #[inline] - fn as_mut(&mut self) -> &mut Inner { + fn as_mut(&mut self) -> &mut Inner { unsafe{ &mut *self.inner.get() } } } -impl WsClientWriter { +impl WsClientWriter { /// Write payload #[inline] diff --git a/src/ws/connect.rs b/src/ws/connect.rs deleted file mode 100644 index 4cb7d0339..000000000 --- a/src/ws/connect.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::io; -use std::net::SocketAddr; -use std::collections::VecDeque; -use std::time::Duration; - -use actix::Arbiter; -use trust_dns_resolver::ResolverFuture; -use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; -use trust_dns_resolver::lookup_ip::LookupIpFuture; -use futures::{Async, Future, Poll}; -use tokio_core::reactor::Timeout; -use tokio_core::net::{TcpStream, TcpStreamNew}; - - -#[derive(Fail, Debug)] -pub enum TcpConnectorError { - /// Failed to resolve the hostname - #[fail(display = "Failed resolving hostname: {}", _0)] - Dns(String), - - /// Address is invalid - #[fail(display = "Invalid input: {}", _0)] - InvalidInput(&'static str), - - /// Connecting took too long - #[fail(display = "Timeout out while establishing connection")] - Timeout, - - /// Connection io error - #[fail(display = "{}", _0)] - IoError(io::Error), -} - -pub struct TcpConnector { - lookup: Option, - port: u16, - ips: VecDeque, - error: Option, - timeout: Timeout, - stream: Option, -} - -impl TcpConnector { - - pub fn new>(addr: S, port: u16, timeout: Duration) -> TcpConnector { - // try to parse as a regular SocketAddr first - if let Ok(addr) = addr.as_ref().parse() { - let mut ips = VecDeque::new(); - ips.push_back(addr); - - TcpConnector { - lookup: None, - port: port, - ips: ips, - error: None, - stream: None, - timeout: Timeout::new(timeout, Arbiter::handle()).unwrap() } - } else { - // we need to do dns resolution - let resolve = match ResolverFuture::from_system_conf(Arbiter::handle()) { - Ok(resolve) => resolve, - Err(err) => { - warn!("Can not create system dns resolver: {}", err); - ResolverFuture::new( - ResolverConfig::default(), - ResolverOpts::default(), - Arbiter::handle()) - } - }; - - TcpConnector { - lookup: Some(resolve.lookup_ip(addr.as_ref())), - port: port, - ips: VecDeque::new(), - error: None, - stream: None, - timeout: Timeout::new(timeout, Arbiter::handle()).unwrap() } - } - } -} - -impl Future for TcpConnector { - type Item = TcpStream; - type Error = TcpConnectorError; - - fn poll(&mut self) -> Poll { - if let Some(err) = self.error.take() { - Err(err) - } else { - // timeout - if let Ok(Async::Ready(_)) = self.timeout.poll() { - return Err(TcpConnectorError::Timeout) - } - - // lookip ips - if let Some(mut lookup) = self.lookup.take() { - match lookup.poll() { - Ok(Async::NotReady) => { - self.lookup = Some(lookup); - return Ok(Async::NotReady) - }, - Ok(Async::Ready(ips)) => { - let port = self.port; - let ips = ips.iter().map(|ip| SocketAddr::new(ip, port)); - self.ips.extend(ips); - if self.ips.is_empty() { - return Err(TcpConnectorError::Dns( - "Expect at least one A dns record".to_owned())) - } - }, - Err(err) => return Err(TcpConnectorError::Dns(format!("{}", err))), - } - } - - // connect - loop { - if let Some(mut new) = self.stream.take() { - match new.poll() { - Ok(Async::Ready(sock)) => - return Ok(Async::Ready(sock)), - Ok(Async::NotReady) => { - self.stream = Some(new); - return Ok(Async::NotReady) - }, - Err(err) => { - if self.ips.is_empty() { - return Err(TcpConnectorError::IoError(err)) - } - } - } - } - - // try to connect - let addr = self.ips.pop_front().unwrap(); - self.stream = Some(TcpStream::connect(&addr, Arbiter::handle())); - } - } - } -} diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 0effcd9ac..3f6c895a9 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -61,8 +61,6 @@ mod context; mod mask; mod client; -mod connect; - use self::frame::Frame; use self::proto::{hash_key, OpCode}; pub use self::proto::CloseCode;