diff --git a/Cargo.toml b/Cargo.toml index 07ffafa32..cdb3bd781 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "0.3.3" +version = "0.3.4" authors = ["Nikolay Kim "] description = "Actix web framework" readme = "README.md" @@ -34,28 +34,29 @@ tls = ["native-tls", "tokio-tls"] alpn = ["openssl", "openssl/v102", "openssl/v110", "tokio-openssl"] [dependencies] -log = "0.4" -failure = "0.1" -failure_derive = "0.1" +base64 = "0.9" +bitflags = "1.0" +brotli2 = "^0.3.2" +failure = "0.1.1" +flate2 = "1.0" h2 = "0.1" http = "^0.1.2" httparse = "1.2" http-range = "0.1" -time = "0.1" +libc = "0.2" +log = "0.4" mime = "0.3" mime_guess = "1.8" +num_cpus = "1.0" +percent-encoding = "1.0" +rand = "0.4" regex = "0.2" -sha1 = "0.4" -url = "1.6" -libc = "0.2" serde = "1.0" serde_json = "1.0" -brotli2 = "^0.3.2" -percent-encoding = "1.0" +sha1 = "0.4" smallvec = "0.6" -bitflags = "1.0" -num_cpus = "1.0" -flate2 = "1.0" +time = "0.1" +url = "1.6" cookie = { version="0.10", features=["percent-encode", "secure"] } # io @@ -65,6 +66,7 @@ bytes = "0.4" futures = "0.1" tokio-io = "0.1" tokio-core = "0.1" +trust-dns-resolver = "0.7" # native-tls native-tls = { version="0.1", optional = true } diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml index 626bfdb48..d26e749d6 100644 --- a/examples/websocket/Cargo.toml +++ b/examples/websocket/Cargo.toml @@ -7,8 +7,14 @@ authors = ["Nikolay Kim "] name = "server" path = "src/main.rs" +[[bin]] +name = "client" +path = "src/client.rs" + [dependencies] env_logger = "*" futures = "0.1" -actix = "^0.4.2" -actix-web = { git = "https://github.com/actix/actix-web.git" } +tokio-core = "0.1" +#actix = "^0.4.2" +actix = { git = "https://github.com/actix/actix.git" } +actix-web = { path="../../" } diff --git a/examples/websocket/src/client.rs b/examples/websocket/src/client.rs new file mode 100644 index 000000000..45c3a7126 --- /dev/null +++ b/examples/websocket/src/client.rs @@ -0,0 +1,108 @@ +//! Simple websocket client. + +#![allow(unused_variables)] +extern crate actix; +extern crate actix_web; +extern crate env_logger; +extern crate futures; +extern crate tokio_core; + +use std::{io, thread}; +use std::time::Duration; + +use actix::*; +use futures::Future; +use tokio_core::net::TcpStream; +use actix_web::ws::{client, Message, WsClientError}; + + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); + let sys = actix::System::new("ws-example"); + + Arbiter::handle().spawn( + client::WsClient::new("http://127.0.0.1:8080/ws/") + .connect().unwrap() + .map_err(|e| { + println!("Error: {}", e); + () + }) + .map(|(reader, writer)| { + let addr: SyncAddress<_> = ChatClient::create(|ctx| { + ctx.add_stream(reader); + ChatClient(writer) + }); + + // start console loop + thread::spawn(move|| { + loop { + let mut cmd = String::new(); + if io::stdin().read_line(&mut cmd).is_err() { + println!("error"); + return + } + addr.send(ClientCommand(cmd)); + } + }); + + () + }) + ); + + let _ = sys.run(); +} + + +struct ChatClient(client::WsWriter); + +#[derive(Message)] +struct ClientCommand(String); + +impl Actor for ChatClient { + type Context = Context; + + fn started(&mut self, ctx: &mut Context) { + // start heartbeats otherwise server will disconnect after 10 seconds + self.hb(ctx) + } + + fn stopping(&mut self, _: &mut Context) -> bool { + println!("Disconnected"); + + // Stop application on disconnect + Arbiter::system().send(actix::msgs::SystemExit(0)); + + true + } +} + +impl ChatClient { + fn hb(&self, ctx: &mut Context) { + ctx.run_later(Duration::new(1, 0), |act, ctx| { + act.0.ping(""); + act.hb(ctx); + }); + } +} + +/// Handle stdin commands +impl Handler for ChatClient { + type Result = (); + + fn handle(&mut self, msg: ClientCommand, ctx: &mut Context) { + self.0.text(msg.0.as_str()) + } +} + +/// Handle server websocket messages +impl Handler> for ChatClient { + type Result = (); + + fn handle(&mut self, msg: Result, ctx: &mut Context) { + match msg { + Ok(Message::Text(txt)) => println!("Server: {:?}", txt), + _ => () + } + } +} diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index a6abac908..cea4732ef 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -44,7 +44,7 @@ impl Handler for MyWebSocket { } fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=trace"); + ::std::env::set_var("RUST_LOG", "actix_web=info"); let _ = env_logger::init(); let sys = actix::System::new("ws-example"); diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 000000000..7fce930fc --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,5 @@ +mod parser; +mod response; + +pub use self::response::ClientResponse; +pub use self::parser::{HttpResponseParser, HttpResponseParserError}; diff --git a/src/client/parser.rs b/src/client/parser.rs new file mode 100644 index 000000000..d072f69b4 --- /dev/null +++ b/src/client/parser.rs @@ -0,0 +1,239 @@ +use std::mem; +use httparse; +use http::{Version, HttpTryFrom, HeaderMap, StatusCode}; +use http::header::{self, HeaderName, HeaderValue}; +use bytes::BytesMut; +use futures::{Poll, Async}; + +use error::{ParseError, PayloadError}; +use payload::{Payload, PayloadWriter, DEFAULT_BUFFER_SIZE}; + +use server::{utils, IoStream}; +use server::h1::{Decoder, chunked}; +use server::encoding::PayloadType; + +use super::ClientResponse; + +const MAX_BUFFER_SIZE: usize = 131_072; +const MAX_HEADERS: usize = 96; + + +pub struct HttpResponseParser { + payload: Option, +} + +enum Decoding { + Paused, + Ready, + NotReady, +} + +struct PayloadInfo { + tx: PayloadType, + decoder: Decoder, +} + +#[derive(Debug)] +pub enum HttpResponseParserError { + Disconnect, + Payload, + Error(ParseError), +} + +impl HttpResponseParser { + pub fn new() -> HttpResponseParser { + HttpResponseParser { + payload: None, + } + } + + fn decode(&mut self, buf: &mut BytesMut) -> Result { + if let Some(ref mut payload) = self.payload { + if payload.tx.capacity() > DEFAULT_BUFFER_SIZE { + return Ok(Decoding::Paused) + } + loop { + match payload.decoder.decode(buf) { + Ok(Async::Ready(Some(bytes))) => { + payload.tx.feed_data(bytes) + }, + Ok(Async::Ready(None)) => { + payload.tx.feed_eof(); + return Ok(Decoding::Ready) + }, + Ok(Async::NotReady) => return Ok(Decoding::NotReady), + Err(err) => { + payload.tx.set_error(err.into()); + return Err(HttpResponseParserError::Payload) + } + } + } + } else { + return Ok(Decoding::Ready) + } + } + + pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut) + -> Poll + where T: IoStream + { + // read payload + if self.payload.is_some() { + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(PayloadError::Incomplete); + } + // http channel should not deal with payload errors + return Err(HttpResponseParserError::Payload) + }, + Err(err) => { + if let Some(ref mut payload) = self.payload { + payload.tx.set_error(err.into()); + } + // http channel should not deal with payload errors + return Err(HttpResponseParserError::Payload) + } + _ => (), + } + match self.decode(buf)? { + Decoding::Ready => self.payload = None, + Decoding::Paused | Decoding::NotReady => return Ok(Async::NotReady), + } + } + + // if buf is empty parse_message will always return NotReady, let's avoid that + let read = if buf.is_empty() { + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + // debug!("Ignored premature client disconnection"); + return Err(HttpResponseParserError::Disconnect); + }, + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(err) => + return Err(HttpResponseParserError::Error(err.into())) + } + false + } else { + true + }; + + loop { + match HttpResponseParser::parse_message(buf).map_err(HttpResponseParserError::Error)? { + Async::Ready((msg, decoder)) => { + // process payload + if let Some(payload) = decoder { + self.payload = Some(payload); + match self.decode(buf)? { + Decoding::Paused | Decoding::NotReady => (), + Decoding::Ready => self.payload = None, + } + } + return Ok(Async::Ready(msg)); + }, + Async::NotReady => { + if buf.capacity() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(HttpResponseParserError::Error(ParseError::TooLarge)); + } + if read { + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + debug!("Ignored premature client disconnection"); + return Err(HttpResponseParserError::Disconnect); + }, + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(err) => + return Err(HttpResponseParserError::Error(err.into())) + } + } else { + return Ok(Async::NotReady) + } + }, + } + } + } + + fn parse_message(buf: &mut BytesMut) -> Poll<(ClientResponse, Option), ParseError> + { + // Parse http message + let bytes_ptr = buf.as_ref().as_ptr() as usize; + let mut headers: [httparse::Header; MAX_HEADERS] = + unsafe{mem::uninitialized()}; + + let (len, version, status, headers_len) = { + let b = unsafe{ let b: &[u8] = buf; mem::transmute(b) }; + let mut resp = httparse::Response::new(&mut headers); + match resp.parse(b)? { + httparse::Status::Complete(len) => { + let version = if resp.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + let status = StatusCode::from_u16(resp.code.unwrap()) + .map_err(|_| ParseError::Status)?; + + (len, version, status, resp.headers.len()) + } + httparse::Status::Partial => return Ok(Async::NotReady), + } + }; + + + let slice = buf.split_to(len).freeze(); + + // convert headers + let mut hdrs = HeaderMap::new(); + for header in headers[..headers_len].iter() { + if let Ok(name) = HeaderName::try_from(header.name) { + let v_start = header.value.as_ptr() as usize - bytes_ptr; + let v_end = v_start + header.value.len(); + let value = unsafe { + HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end)) }; + hdrs.append(name, value); + } else { + return Err(ParseError::Header) + } + } + + let decoder = if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { + // Content-Length + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + Some(Decoder::length(len)) + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(ParseError::Header) + } + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(ParseError::Header) + } + } else if chunked(&hdrs)? { + // Chunked encoding + Some(Decoder::chunked()) + } else if hdrs.contains_key(header::UPGRADE) { + Some(Decoder::eof()) + } else { + None + }; + + if let Some(decoder) = decoder { + let (psender, payload) = Payload::new(false); + let info = PayloadInfo { + tx: PayloadType::new(&hdrs, psender), + decoder: decoder, + }; + Ok(Async::Ready( + (ClientResponse::new(status, version, hdrs, Some(payload)), Some(info)))) + } else { + Ok(Async::Ready( + (ClientResponse::new(status, version, hdrs, None), None))) + } + } +} diff --git a/src/client/response.rs b/src/client/response.rs new file mode 100644 index 000000000..0bd8ed608 --- /dev/null +++ b/src/client/response.rs @@ -0,0 +1,76 @@ +#![allow(dead_code)] +use std::fmt; +use http::{HeaderMap, StatusCode, Version}; +use http::header::HeaderValue; + +use payload::Payload; + + +pub struct ClientResponse { + /// The response's status + status: StatusCode, + + /// The response's version + version: Version, + + /// The response's headers + headers: HeaderMap, + + payload: Option, +} + +impl ClientResponse { + pub fn new(status: StatusCode, version: Version, + headers: HeaderMap, payload: Option) -> Self { + ClientResponse { + status: status, version: version, headers: headers, payload: payload + } + } + + /// Get the HTTP version of this response. + #[inline] + pub fn version(&self) -> Version { + self.version + } + + /// Get the headers from the response. + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get a mutable reference to the headers. + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// Get the status from the server. + #[inline] + pub fn status(&self) -> StatusCode { + self.status + } + + /// Set the `StatusCode` for this response. + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + &mut self.status + } +} + +impl fmt::Debug for ClientResponse { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = write!( + f, "\nClientResponse {:?} {}\n", self.version, self.status); + let _ = write!(f, " headers:\n"); + for key in self.headers.keys() { + let vals: Vec<_> = self.headers.get_all(key).iter().collect(); + if vals.len() > 1 { + let _ = write!(f, " {:?}: {:?}\n", key, vals); + } else { + let _ = write!(f, " {:?}: {:?}\n", key, vals[0]); + } + } + res + } +} diff --git a/src/error.rs b/src/error.rs index 084249217..56d817d7c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,7 +16,7 @@ use http::{header, StatusCode, Error as HttpError}; use http::uri::InvalidUriBytes; use http_range::HttpRangeParseError; use serde_json::error::Error as JsonError; -use url::ParseError as UrlParseError; +pub use url::ParseError as UrlParseError; // re-exports pub use cookie::{ParseError as CookieParseError}; @@ -106,6 +106,9 @@ default impl ResponseError for T { /// `InternalServerError` for `JsonError` impl ResponseError for JsonError {} +/// `InternalServerError` for `UrlParseError` +impl ResponseError for UrlParseError {} + /// Return `InternalServerError` for `HttpError`, /// Response generation can return `HttpError`, so it is internal error impl ResponseError for HttpError {} diff --git a/src/httpresponse.rs b/src/httpresponse.rs index b9cdb60b3..20607b95b 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -1,17 +1,15 @@ -//! Pieces pertaining to the HTTP response. +//! Http response use std::{mem, str, fmt}; use std::io::Write; use std::cell::RefCell; -use std::convert::Into; use std::collections::VecDeque; -use cookie::CookieJar; +use cookie::{Cookie, CookieJar}; use bytes::{Bytes, BytesMut, BufMut}; use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use serde_json; use serde::Serialize; -use cookie::Cookie; use body::Body; use error::Error; @@ -261,7 +259,6 @@ impl HttpResponseBuilder { /// } /// fn main() {} /// ``` - #[inline] pub fn header(&mut self, key: K, value: V) -> &mut Self where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom diff --git a/src/lib.rs b/src/lib.rs index 44d4d1518..5bf8020cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ #[macro_use] extern crate log; extern crate time; +extern crate base64; extern crate bytes; extern crate sha1; extern crate regex; @@ -66,6 +67,7 @@ extern crate httparse; extern crate http_range; extern crate mime; extern crate mime_guess; +extern crate rand; extern crate url; extern crate libc; extern crate serde; @@ -76,6 +78,7 @@ extern crate percent_encoding; extern crate smallvec; extern crate num_cpus; extern crate h2 as http2; +extern crate trust_dns_resolver; #[macro_use] extern crate actix; #[cfg(test)] @@ -106,6 +109,8 @@ mod resource; mod handler; mod pipeline; +mod client; + pub mod fs; pub mod ws; pub mod error; diff --git a/src/server/h1.rs b/src/server/h1.rs index b6629a1f5..b039b09ed 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -550,7 +550,7 @@ impl Reader { } /// Check if request has chunked transfer encoding -fn chunked(headers: &HeaderMap) -> Result { +pub fn chunked(headers: &HeaderMap) -> Result { if let Some(encodings) = headers.get(header::TRANSFER_ENCODING) { if let Ok(s) = encodings.to_str() { Ok(s.to_lowercase().contains("chunked")) @@ -567,7 +567,7 @@ fn chunked(headers: &HeaderMap) -> Result { /// If a message body does not include a Transfer-Encoding, it *should* /// include a Content-Length header. #[derive(Debug, Clone, PartialEq)] -struct Decoder { +pub struct Decoder { kind: Kind, } diff --git a/src/server/mod.rs b/src/server/mod.rs index 869788ddc..39df2fc8d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -9,14 +9,14 @@ use tokio_core::net::TcpStream; mod srv; mod worker; mod channel; -mod encoding; -mod h1; +pub(crate) mod encoding; +pub(crate) mod h1; mod h2; mod h1writer; mod h2writer; mod settings; -mod shared; -mod utils; +pub(crate) mod shared; +pub(crate) mod utils; pub use self::srv::HttpServer; pub use self::settings::ServerSettings; diff --git a/src/ws/client.rs b/src/ws/client.rs new file mode 100644 index 000000000..c4fa762b3 --- /dev/null +++ b/src/ws/client.rs @@ -0,0 +1,561 @@ +//! Http client request +use std::{fmt, io, str}; +use std::rc::Rc; +use std::time::Duration; +use std::cell::UnsafeCell; + +use base64; +use rand; +use cookie::{Cookie, CookieJar}; +use bytes::BytesMut; +use http::{Method, Version, HeaderMap, HttpTryFrom, StatusCode, Error as HttpError}; +use http::header::{self, HeaderName, HeaderValue}; +use url::Url; +use sha1::Sha1; +use futures::{Async, Future, Poll, Stream}; +// use futures::unsync::oneshot; +use tokio_core::net::TcpStream; + +use body::{Body, Binary}; +use error::UrlParseError; +use headers::ContentEncoding; +use server::shared::SharedBytes; + +use server::{utils, IoStream}; +use client::{HttpResponseParser, HttpResponseParserError}; + +use super::Message; +use super::proto::{CloseCode, OpCode}; +use super::frame::Frame; +use super::writer::Writer; +use super::connect::{TcpConnector, TcpConnectorError}; + +/// Websockt client error +#[derive(Fail, Debug)] +pub enum WsClientError { + #[fail(display="Invalid url")] + InvalidUrl, + #[fail(display="Invalid response status")] + InvalidResponseStatus, + #[fail(display="Invalid upgrade header")] + InvalidUpgradeHeader, + #[fail(display="Invalid connection header")] + InvalidConnectionHeader, + #[fail(display="Invalid challenge response")] + InvalidChallengeResponse, + #[fail(display="Http parsing error")] + Http(HttpError), + #[fail(display="Url parsing error")] + Url(UrlParseError), + #[fail(display="Response parsing error")] + ResponseParseError(HttpResponseParserError), + #[fail(display="{}", _0)] + Connection(TcpConnectorError), + #[fail(display="{}", _0)] + Io(io::Error), + #[fail(display="Disconnected")] + Disconnected, +} + +impl From for WsClientError { + fn from(err: HttpError) -> WsClientError { + WsClientError::Http(err) + } +} + +impl From for WsClientError { + fn from(err: UrlParseError) -> WsClientError { + WsClientError::Url(err) + } +} + +impl From for WsClientError { + fn from(err: TcpConnectorError) -> WsClientError { + WsClientError::Connection(err) + } +} + +impl From for WsClientError { + fn from(err: io::Error) -> WsClientError { + WsClientError::Io(err) + } +} + +impl From for WsClientError { + fn from(err: HttpResponseParserError) -> WsClientError { + WsClientError::ResponseParseError(err) + } +} + +type WsFuture = Future, WsWriter), Error=WsClientError>; + +/// Websockt client +pub struct WsClient { + request: Option, + err: Option, + http_err: Option, + cookies: Option, + origin: Option, + protocols: Option, +} + +impl WsClient { + + pub fn new>(url: S) -> WsClient { + let mut cl = WsClient { + request: None, + err: None, + http_err: None, + cookies: None, + origin: None, + protocols: None }; + + match Url::parse(url.as_ref()) { + Ok(url) => { + if url.scheme() != "http" && url.scheme() != "https" && + url.scheme() != "ws" && url.scheme() != "wss" || !url.has_host() { + cl.err = Some(WsClientError::InvalidUrl); + } else { + cl.request = Some(ClientRequest::new(Method::GET, url)); + } + }, + Err(err) => cl.err = Some(err.into()), + } + cl + } + + pub fn protocols(&mut self, protos: U) -> &mut Self + where U: IntoIterator + 'static, + V: AsRef + { + let mut protos = protos.into_iter() + .fold(String::new(), |acc, s| {acc + s.as_ref() + ","}); + protos.pop(); + self.protocols = Some(protos); + self + } + + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Set request Origin + pub fn origin(&mut self, origin: V) -> &mut Self + where HeaderValue: HttpTryFrom + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.http_err = Some(e.into()), + } + self + } + + pub fn header(&mut self, key: K, value: V) -> &mut Self + where HeaderName: HttpTryFrom, + HeaderValue: HttpTryFrom + { + if let Some(parts) = parts(&mut self.request, &self.err, &self.http_err) { + match HeaderName::try_from(key) { + Ok(key) => { + match HeaderValue::try_from(value) { + Ok(value) => { parts.headers.append(key, value); } + Err(e) => self.http_err = Some(e.into()), + } + }, + Err(e) => self.http_err = Some(e.into()), + }; + } + self + } + + pub fn connect(&mut self) -> Result>, WsClientError> { + if let Some(e) = self.err.take() { + return Err(e) + } + if let Some(e) = self.http_err.take() { + return Err(e.into()) + } + let mut request = self.request.take().expect("cannot reuse request builder"); + + // headers + if let Some(ref jar) = self.cookies { + for cookie in jar.delta() { + request.headers.append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie.to_string()).map_err(HttpError::from)?); + } + } + + // origin + if let Some(origin) = self.origin.take() { + request.headers.insert(header::ORIGIN, origin); + } + + request.headers.insert(header::UPGRADE, HeaderValue::from_static("websocket")); + request.headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade")); + request.headers.insert( + HeaderName::try_from("SEC-WEBSOCKET-VERSION").unwrap(), + HeaderValue::from_static("13")); + + if let Some(protocols) = self.protocols.take() { + request.headers.insert( + HeaderName::try_from("SEC-WEBSOCKET-PROTOCOL").unwrap(), + HeaderValue::try_from(protocols.as_str()).unwrap()); + } + + let connect = TcpConnector::new( + request.url.host_str().unwrap(), + request.url.port().unwrap_or(80), Duration::from_secs(5)); + + Ok(Box::new( + connect + .from_err() + .and_then(move |stream| WsHandshake::new(stream, request)))) + } +} + +#[inline] +fn parts<'a>(parts: &'a mut Option, + err: &Option, + http_err: &Option) -> Option<&'a mut ClientRequest> +{ + if err.is_some() || http_err.is_some() { + return None + } + parts.as_mut() +} + +pub(crate) struct ClientRequest { + pub url: Url, + pub method: Method, + pub version: Version, + pub headers: HeaderMap, + pub body: Body, + pub chunked: Option, + pub encoding: ContentEncoding, +} + +impl ClientRequest { + + #[inline] + fn new(method: Method, url: Url) -> ClientRequest { + ClientRequest { + url: url, + method: method, + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + body: Body::Empty, + chunked: None, + encoding: ContentEncoding::Auto, + } + } +} + +impl fmt::Debug for ClientRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = write!(f, "\nClientRequest {:?} {}:{}\n", + self.version, self.method, self.url); + let _ = write!(f, " headers:\n"); + for key in self.headers.keys() { + let vals: Vec<_> = self.headers.get_all(key).iter().collect(); + if vals.len() > 1 { + let _ = write!(f, " {:?}: {:?}\n", key, vals); + } else { + let _ = write!(f, " {:?}: {:?}\n", key, vals[0]); + } + } + res + } +} + +struct WsInner { + stream: T, + writer: Writer, + parser: HttpResponseParser, + parser_buf: BytesMut, + closed: bool, + error_sent: bool, +} + +struct WsHandshake { + inner: Option>, + request: ClientRequest, + sent: bool, + key: String, +} + +impl WsHandshake { + fn new(stream: T, mut request: ClientRequest) -> WsHandshake { + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let sec_key: [u8; 16] = rand::random(); + let key = base64::encode(&sec_key); + + request.headers.insert( + HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), + HeaderValue::try_from(key.as_str()).unwrap()); + + let inner = WsInner { + stream: stream, + writer: Writer::new(SharedBytes::default()), + parser: HttpResponseParser::new(), + parser_buf: BytesMut::new(), + closed: false, + error_sent: false, + }; + + WsHandshake { + key: key, + inner: Some(inner), + request: request, + sent: false, + } + } +} + +impl Future for WsHandshake { + type Item = (WsReader, WsWriter); + type Error = WsClientError; + + fn poll(&mut self) -> Poll { + let mut inner = self.inner.take().unwrap(); + + if !self.sent { + self.sent = true; + inner.writer.start(&mut self.request); + } + if let Err(err) = inner.writer.poll_completed(&mut inner.stream, false) { + return Err(err.into()) + } + + match inner.parser.parse(&mut inner.stream, &mut inner.parser_buf) { + Ok(Async::Ready(resp)) => { + // verify response + if resp.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus) + } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + return Err(WsClientError::InvalidUpgradeHeader) + } + // Check for "CONNECTION" header + let has_hdr = if let Some(conn) = resp.headers().get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + s.to_lowercase().contains("upgrade") + } else { false } + } else { false }; + if !has_hdr { + return Err(WsClientError::InvalidConnectionHeader) + } + + let match_key = if let Some(key) = resp.headers().get( + HeaderName::try_from("SEC-WEBSOCKET-ACCEPT").unwrap()) + { + // ... field is constructed by concatenating /key/ ... + // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(self.key.as_ref()); + sha1.update(WS_GUID); + key.as_bytes() == base64::encode(&sha1.digest().bytes()).as_bytes() + } else { + false + }; + if !match_key { + return Err(WsClientError::InvalidChallengeResponse) + } + + let inner = Rc::new(UnsafeCell::new(Inner{inner: inner})); + Ok(Async::Ready( + (WsReader{inner: Rc::clone(&inner)}, + WsWriter{inner: inner}))) + }, + Ok(Async::NotReady) => { + self.inner = Some(inner); + Ok(Async::NotReady) + }, + Err(err) => Err(err.into()) + } + } +} + + +struct Inner { + inner: WsInner, +} + +pub struct WsReader { + inner: Rc>> +} + +impl WsReader { + #[inline] + fn as_mut(&mut self) -> &mut Inner { + unsafe{ &mut *self.inner.get() } + } +} + +impl Stream for WsReader { + type Item = Message; + type Error = WsClientError; + + fn poll(&mut self) -> Poll, Self::Error> { + let inner = self.as_mut(); + let mut done = false; + + match utils::read_from_io(&mut inner.inner.stream, &mut inner.inner.parser_buf) { + Ok(Async::Ready(0)) => { + done = true; + inner.inner.closed = true; + }, + Ok(Async::Ready(_)) | Ok(Async::NotReady) => (), + Err(err) => + return Err(err.into()) + } + + // write + let _ = inner.inner.writer.poll_completed(&mut inner.inner.stream, false); + + // read + match Frame::parse(&mut inner.inner.parser_buf) { + Ok(Some(frame)) => { + // trace!("WsFrame {}", frame); + let (_finished, opcode, payload) = frame.unpack(); + + match opcode { + OpCode::Continue => unimplemented!(), + OpCode::Bad => + Ok(Async::Ready(Some(Message::Error))), + OpCode::Close => { + inner.inner.closed = true; + inner.inner.error_sent = true; + Ok(Async::Ready(Some(Message::Closed))) + }, + OpCode::Ping => + Ok(Async::Ready(Some( + Message::Ping( + String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Pong => + Ok(Async::Ready(Some( + Message::Pong( + String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Binary => + Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Text => { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => + Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => + Ok(Async::Ready(Some(Message::Error))), + } + } + } + } + Ok(None) => { + if done { + Ok(Async::Ready(None)) + } else if inner.inner.closed { + if !inner.inner.error_sent { + inner.inner.error_sent = true; + Ok(Async::Ready(Some(Message::Closed))) + } else { + Ok(Async::Ready(None)) + } + } else { + Ok(Async::NotReady) + } + }, + Err(err) => { + inner.inner.closed = true; + inner.inner.error_sent = true; + Err(err.into()) + } + } + } +} + +pub struct WsWriter { + inner: Rc>> +} + +impl WsWriter { + #[inline] + fn as_mut(&mut self) -> &mut Inner { + unsafe{ &mut *self.inner.get() } + } +} + +impl WsWriter { + + /// Write payload + #[inline] + fn write>(&mut self, data: B) { + if !self.as_mut().inner.closed { + let _ = self.as_mut().inner.writer.write(&data.into()); + } else { + warn!("Trying to write to disconnected response"); + } + } + + /// Send text frame + pub fn text(&mut self, text: &str) { + let mut frame = Frame::message(Vec::from(text), OpCode::Text, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + self.write(buf); + } + + /// Send binary frame + pub fn binary>(&mut self, data: B) { + let mut frame = Frame::message(data, OpCode::Binary, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + self.write(buf); + } + + /// Send ping frame + pub fn ping(&mut self, message: &str) { + let mut frame = Frame::message(Vec::from(message), OpCode::Ping, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + self.write(buf); + } + + /// Send pong frame + pub fn pong(&mut self, message: &str) { + let mut frame = Frame::message(Vec::from(message), OpCode::Pong, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + self.write(buf); + } + + /// Send close frame + pub fn close(&mut self, code: CloseCode, reason: &str) { + let mut frame = Frame::close(code, reason); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + self.write(buf); + } +} diff --git a/src/ws/connect.rs b/src/ws/connect.rs new file mode 100644 index 000000000..ed664f2e8 --- /dev/null +++ b/src/ws/connect.rs @@ -0,0 +1,141 @@ +use std::io; +use std::net::SocketAddr; +use std::collections::VecDeque; +use std::time::Duration; + +use actix::Arbiter; +use trust_dns_resolver::ResolverFuture; +use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; +use trust_dns_resolver::lookup_ip::LookupIpFuture; +use futures::{Async, Future, Poll}; +use tokio_core::reactor::Timeout; +use tokio_core::net::{TcpStream, TcpStreamNew}; + + +#[derive(Fail, Debug)] +pub enum TcpConnectorError { + /// Failed to resolve the hostname + #[fail(display = "Failed resolving hostname: {}", _0)] + Dns(String), + + /// Address is invalid + #[fail(display = "Invalid input: {}", _0)] + InvalidInput(&'static str), + + /// Connecting took too long + #[fail(display = "Timeout out while establishing connection")] + Timeout, + + /// Connection io error + #[fail(display = "{}", _0)] + IoError(io::Error), +} + +pub struct TcpConnector { + lookup: Option, + port: u16, + ips: VecDeque, + error: Option, + timeout: Timeout, + stream: Option, +} + +impl TcpConnector { + + pub fn new>(addr: S, port: u16, timeout: Duration) -> TcpConnector { + println!("TES: {:?} {:?}", addr.as_ref(), port); + + // try to parse as a regular SocketAddr first + if let Ok(addr) = addr.as_ref().parse() { + let mut ips = VecDeque::new(); + ips.push_back(addr); + + TcpConnector { + lookup: None, + port: port, + ips: ips, + error: None, + stream: None, + timeout: Timeout::new(timeout, Arbiter::handle()).unwrap() } + } else { + // we need to do dns resolution + let resolve = match ResolverFuture::from_system_conf(Arbiter::handle()) { + Ok(resolve) => resolve, + Err(err) => { + warn!("Can not create system dns resolver: {}", err); + ResolverFuture::new( + ResolverConfig::default(), + ResolverOpts::default(), + Arbiter::handle()) + } + }; + + TcpConnector { + lookup: Some(resolve.lookup_ip(addr.as_ref())), + port: port, + ips: VecDeque::new(), + error: None, + stream: None, + timeout: Timeout::new(timeout, Arbiter::handle()).unwrap() } + } + } +} + +impl Future for TcpConnector { + type Item = TcpStream; + type Error = TcpConnectorError; + + fn poll(&mut self) -> Poll { + if let Some(err) = self.error.take() { + Err(err) + } else { + // timeout + if let Ok(Async::Ready(_)) = self.timeout.poll() { + return Err(TcpConnectorError::Timeout) + } + + // lookip ips + if let Some(mut lookup) = self.lookup.take() { + match lookup.poll() { + Ok(Async::NotReady) => { + self.lookup = Some(lookup); + return Ok(Async::NotReady) + }, + Ok(Async::Ready(ips)) => { + let port = self.port; + let ips = ips.iter().map(|ip| SocketAddr::new(ip, port)); + self.ips.extend(ips); + if self.ips.is_empty() { + return Err(TcpConnectorError::Dns( + "Expect at least one A dns record".to_owned())) + } + }, + Err(err) => return Err(TcpConnectorError::Dns(format!("{}", err))), + } + } + + // connect + loop { + if let Some(mut new) = self.stream.take() { + match new.poll() { + Ok(Async::Ready(sock)) => + return Ok(Async::Ready(sock)), + Ok(Async::NotReady) => { + self.stream = Some(new); + return Ok(Async::NotReady) + }, + Err(err) => { + if self.ips.is_empty() { + return Err(TcpConnectorError::IoError(err)) + } + } + } + } + + // try to connect + let addr = self.ips.pop_front().unwrap(); + self.stream = Some(TcpStream::connect(&addr, Arbiter::handle())); + } + } + } +} diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 88935a51c..6ac87ecd7 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -60,11 +60,18 @@ mod proto; mod context; mod mask; +mod connect; +mod writer; + +pub mod client; + use ws::frame::Frame; use ws::proto::{hash_key, OpCode}; pub use ws::proto::CloseCode; pub use ws::context::WebsocketContext; +pub use self::client::{WsClient, WsClientError, WsReader, WsWriter}; + const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION"; diff --git a/src/ws/proto.rs b/src/ws/proto.rs index a1b72f69f..f406969f8 100644 --- a/src/ws/proto.rs +++ b/src/ws/proto.rs @@ -1,7 +1,7 @@ use std::fmt; use std::convert::{Into, From}; use sha1; - +use base64; use self::OpCode::*; /// Operation codes as part of rfc6455. @@ -188,10 +188,7 @@ impl From for CloseCode { } } - static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; -static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - // TODO: hash is always same size, we dont need String pub(crate) fn hash_key(key: &[u8]) -> String { @@ -200,48 +197,7 @@ pub(crate) fn hash_key(key: &[u8]) -> String { hasher.update(key); hasher.update(WS_GUID.as_bytes()); - encode_base64(&hasher.digest().bytes()) -} - - -// This code is based on rustc_serialize base64 STANDARD -fn encode_base64(data: &[u8]) -> String { - let len = data.len(); - let mod_len = len % 3; - - let mut encoded = vec![b'='; (len + 2) / 3 * 4]; - { - let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c)); - let mut out_iter = encoded.iter_mut(); - - let enc = |val| BASE64[val as usize]; - let mut write = |val| *out_iter.next().unwrap() = val; - - while let (Some(one), Some(two), Some(three)) = (in_iter.next(), in_iter.next(), in_iter.next()) { - let g24 = one << 16 | two << 8 | three; - write(enc((g24 >> 18) & 63)); - write(enc((g24 >> 12) & 63)); - write(enc((g24 >> 6 ) & 63)); - write(enc(g24 & 63)); - } - - match mod_len { - 1 => { - let pad = u32::from(data[len-1]) << 16; - write(enc((pad >> 18) & 63)); - write(enc((pad >> 12) & 63)); - } - 2 => { - let pad = u32::from(data[len-2]) << 16 | u32::from(data[len-1]) << 8; - write(enc((pad >> 18) & 63)); - write(enc((pad >> 12) & 63)); - write(enc((pad >> 6) & 63)); - } - _ => (), - } - } - - String::from_utf8(encoded).unwrap() + base64::encode(&hasher.digest().bytes()) } diff --git a/src/ws/writer.rs b/src/ws/writer.rs new file mode 100644 index 000000000..d31f82c50 --- /dev/null +++ b/src/ws/writer.rs @@ -0,0 +1,152 @@ +#![allow(dead_code)] +use std::io; +use bytes::BufMut; +use futures::{Async, Poll}; +use tokio_io::AsyncWrite; +// use http::header::{HeaderValue, CONNECTION, DATE}; + +use body::Binary; +use server::{WriterState, MAX_WRITE_BUFFER_SIZE}; +use server::shared::SharedBytes; + +use super::client::ClientRequest; + + +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific + +bitflags! { + struct Flags: u8 { + const STARTED = 0b0000_0001; + const UPGRADE = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const DISCONNECTED = 0b0000_1000; + } +} + +pub(crate) struct Writer { + flags: Flags, + written: u64, + headers_size: u32, + buffer: SharedBytes, +} + +impl Writer { + + pub fn new(buf: SharedBytes) -> Writer { + Writer { + flags: Flags::empty(), + written: 0, + headers_size: 0, + buffer: buf, + } + } + + pub fn disconnected(&mut self) { + self.buffer.take(); + } + + pub fn keepalive(&self) -> bool { + self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) + } + + fn write_to_stream(&mut self, stream: &mut T) -> io::Result { + while !self.buffer.is_empty() { + match stream.write(self.buffer.as_ref()) { + Ok(0) => { + self.disconnected(); + return Ok(WriterState::Done); + }, + Ok(n) => { + let _ = self.buffer.split_to(n); + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + Err(err) => return Err(err), + } + } + Ok(WriterState::Done) + } +} + +impl Writer { + + pub fn start(&mut self, msg: &mut ClientRequest) { + // prepare task + self.flags.insert(Flags::STARTED); + + // render message + { + let buffer = self.buffer.get_mut(); + buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE); + + // status line + // helpers::write_status_line(version, msg.status().as_u16(), &mut buffer); + // buffer.extend_from_slice(msg.reason().as_bytes()); + buffer.extend_from_slice(b"GET /ws/ HTTP/1.1"); + buffer.extend_from_slice(b"\r\n"); + + // write headers + for (key, value) in &msg.headers { + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + buffer.reserve(k.len() + v.len() + 4); + buffer.put_slice(k); + buffer.put_slice(b": "); + buffer.put_slice(v); + buffer.put_slice(b"\r\n"); + } + + // using helpers::date is quite a lot faster + //if !msg.headers.contains_key(DATE) { + // helpers::date(&mut buffer); + //} else { + // msg eof + buffer.extend_from_slice(b"\r\n"); + //} + self.headers_size = buffer.len() as u32; + } + } + + pub fn write(&mut self, payload: &Binary) -> io::Result { + self.written += payload.len() as u64; + if !self.flags.contains(Flags::DISCONNECTED) { + self.buffer.extend_from_slice(payload.as_ref()) + } + + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + Ok(WriterState::Pause) + } else { + Ok(WriterState::Done) + } + } + + pub fn write_eof(&mut self) -> io::Result { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + Ok(WriterState::Pause) + } else { + Ok(WriterState::Done) + } + } + + #[inline] + pub fn poll_completed(&mut self, stream: &mut T, shutdown: bool) + -> Poll<(), io::Error> + { + match self.write_to_stream(stream) { + Ok(WriterState::Done) => { + if shutdown { + stream.shutdown() + } else { + Ok(Async::Ready(())) + } + }, + Ok(WriterState::Pause) => Ok(Async::NotReady), + Err(err) => Err(err) + } + } +}