diff --git a/Cargo.toml b/Cargo.toml index 9193aeeaa..87bd05b8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,10 @@ time = "0.1" encoding = "0.2" lazy_static = "1.0" serde_urlencoded = "0.5.3" + cookie = { version="0.11", features=["percent-encode"] } +percent-encoding = "1.0" +url = { version="1.7", features=["query_encoding"] } # io net2 = "0.2" diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 000000000..a5582fea8 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,6 @@ +//! Http client api +mod request; +mod response; + +pub use self::request::{ClientRequest, ClientRequestBuilder}; +pub use self::response::ClientResponse; diff --git a/src/client/request.rs b/src/client/request.rs new file mode 100644 index 000000000..bf82927e9 --- /dev/null +++ b/src/client/request.rs @@ -0,0 +1,564 @@ +use std::fmt; +use std::fmt::Write as FmtWrite; +use std::io::Write; + +use bytes::{BufMut, BytesMut}; +use cookie::{Cookie, CookieJar}; +use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; +use urlcrate::Url; + +use header::{self, Header, IntoHeaderValue}; +use http::{ + uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, + Uri, Version, +}; + +/// An HTTP Client Request +/// +/// ```rust +/// # extern crate actix_web; +/// # extern crate futures; +/// # extern crate tokio; +/// # use futures::Future; +/// # use std::process; +/// use actix_web::{actix, client}; +/// +/// fn main() { +/// actix::run( +/// || client::ClientRequest::get("http://www.rust-lang.org") // <- Create request builder +/// .header("User-Agent", "Actix-web") +/// .finish().unwrap() +/// .send() // <- Send http request +/// .map_err(|_| ()) +/// .and_then(|response| { // <- server http response +/// println!("Response: {:?}", response); +/// # actix::System::current().stop(); +/// Ok(()) +/// }), +/// ); +/// } +/// ``` +pub struct ClientRequest { + uri: Uri, + method: Method, + version: Version, + headers: HeaderMap, + chunked: bool, + upgrade: bool, +} + +impl Default for ClientRequest { + fn default() -> ClientRequest { + ClientRequest { + uri: Uri::default(), + method: Method::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + chunked: false, + upgrade: false, + } + } +} + +impl ClientRequest { + /// Create request builder for `GET` request + pub fn get>(uri: U) -> ClientRequestBuilder { + let mut builder = ClientRequest::build(); + builder.method(Method::GET).uri(uri); + builder + } + + /// Create request builder for `HEAD` request + pub fn head>(uri: U) -> ClientRequestBuilder { + let mut builder = ClientRequest::build(); + builder.method(Method::HEAD).uri(uri); + builder + } + + /// Create request builder for `POST` request + pub fn post>(uri: U) -> ClientRequestBuilder { + let mut builder = ClientRequest::build(); + builder.method(Method::POST).uri(uri); + builder + } + + /// Create request builder for `PUT` request + pub fn put>(uri: U) -> ClientRequestBuilder { + let mut builder = ClientRequest::build(); + builder.method(Method::PUT).uri(uri); + builder + } + + /// Create request builder for `DELETE` request + pub fn delete>(uri: U) -> ClientRequestBuilder { + let mut builder = ClientRequest::build(); + builder.method(Method::DELETE).uri(uri); + builder + } +} + +impl ClientRequest { + /// Create client request builder + pub fn build() -> ClientRequestBuilder { + ClientRequestBuilder { + request: Some(ClientRequest::default()), + err: None, + cookies: None, + default_headers: true, + } + } + + /// Get the request URI + #[inline] + pub fn uri(&self) -> &Uri { + &self.uri + } + + /// Set client request URI + #[inline] + pub fn set_uri(&mut self, uri: Uri) { + self.uri = uri + } + + /// Get the request method + #[inline] + pub fn method(&self) -> &Method { + &self.method + } + + /// Set HTTP `Method` for the request + #[inline] + pub fn set_method(&mut self, method: Method) { + self.method = method + } + + /// Get HTTP version for the request + #[inline] + pub fn version(&self) -> Version { + self.version + } + + /// Set http `Version` for the request + #[inline] + pub fn set_version(&mut self, version: Version) { + self.version = version + } + + /// Get the headers from the request + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// is chunked encoding enabled + #[inline] + pub fn chunked(&self) -> bool { + self.chunked + } + + /// is upgrade request + #[inline] + pub fn upgrade(&self) -> bool { + self.upgrade + } +} + +impl fmt::Debug for ClientRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nClientRequest {:?} {}:{}", + self.version, self.method, self.uri + )?; + writeln!(f, " headers:")?; + for (key, val) in self.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +/// An HTTP Client request builder +/// +/// This type can be used to construct an instance of `ClientRequest` through a +/// builder-like pattern. +pub struct ClientRequestBuilder { + request: Option, + err: Option, + cookies: Option, + default_headers: bool, +} + +impl ClientRequestBuilder { + /// Set HTTP URI of request. + #[inline] + pub fn uri>(&mut self, uri: U) -> &mut Self { + match Url::parse(uri.as_ref()) { + Ok(url) => self._uri(url.as_str()), + Err(_) => self._uri(uri.as_ref()), + } + } + + fn _uri(&mut self, url: &str) -> &mut Self { + match Uri::try_from(url) { + Ok(uri) => { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.uri = uri; + } + } + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set HTTP method of this request. + #[inline] + pub fn method(&mut self, method: Method) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.method = method; + } + self + } + + /// Set HTTP method of this request. + #[inline] + pub fn get_method(&mut self) -> &Method { + let parts = self.request.as_ref().expect("cannot reuse request builder"); + &parts.method + } + + /// Set HTTP version of this request. + /// + /// By default requests's HTTP version depends on network stream + #[inline] + pub fn version(&mut self, version: Version) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.version = version; + } + self + } + + /// Set a header. + /// + /// ```rust + /// # extern crate mime; + /// # extern crate actix_web; + /// # use actix_web::client::*; + /// # + /// use actix_web::{client, http}; + /// + /// fn main() { + /// let req = client::ClientRequest::build() + /// .set(http::header::Date::now()) + /// .set(http::header::ContentType(mime::TEXT_HTML)) + /// .finish() + /// .unwrap(); + /// } + /// ``` + #[doc(hidden)] + pub fn set(&mut self, hdr: H) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + match hdr.try_into() { + Ok(value) => { + parts.headers.insert(H::name(), value); + } + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Append a header. + /// + /// Header gets appended to existing header. + /// To override header use `set_header()` method. + /// + /// ```rust + /// # extern crate http; + /// # extern crate actix_web; + /// # use actix_web::client::*; + /// # + /// use http::header; + /// + /// fn main() { + /// let req = ClientRequest::build() + /// .header("X-TEST", "value") + /// .header(header::CONTENT_TYPE, "application/json") + /// .finish() + /// .unwrap(); + /// } + /// ``` + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set a header. + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set a header only if it is not yet set. + pub fn set_header_if_none(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => if !parts.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + } + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Enable connection upgrade + #[inline] + pub fn upgrade(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.upgrade = true; + } + self + } + + /// Set request's content type + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + HeaderValue: HttpTryFrom, + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderValue::try_from(value) { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set content length + #[inline] + pub fn content_length(&mut self, len: u64) -> &mut Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + + /// Set a cookie + /// + /// ```rust + /// # extern crate actix_web; + /// use actix_web::{client, http}; + /// + /// fn main() { + /// let req = client::ClientRequest::build() + /// .cookie( + /// http::Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .finish() + /// .unwrap(); + /// } + /// ``` + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Do not add default request headers. + /// By default `Accept-Encoding` and `User-Agent` headers are set. + pub fn no_default_headers(&mut self) -> &mut Self { + self.default_headers = false; + self + } + + /// This method calls provided closure with builder reference if + /// value is `true`. + pub fn if_true(&mut self, value: bool, f: F) -> &mut Self + where + F: FnOnce(&mut ClientRequestBuilder), + { + if value { + f(self); + } + self + } + + /// This method calls provided closure with builder reference if + /// value is `Some`. + pub fn if_some(&mut self, value: Option, f: F) -> &mut Self + where + F: FnOnce(T, &mut ClientRequestBuilder), + { + if let Some(val) = value { + f(val, self); + } + self + } + + /// Set a body and generate `ClientRequest`. + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn finish(&mut self) -> Result { + if let Some(e) = self.err.take() { + return Err(e); + } + + if self.default_headers { + // enable br only for https + let https = if let Some(parts) = parts(&mut self.request, &self.err) { + parts + .uri + .scheme_part() + .map(|s| s == &uri::Scheme::HTTPS) + .unwrap_or(true) + } else { + true + }; + + if https { + self.set_header_if_none(header::ACCEPT_ENCODING, "br, gzip, deflate"); + } else { + self.set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate"); + } + + // set request host header + if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(host) = parts.uri.host() { + if !parts.headers.contains_key(header::HOST) { + let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); + + let _ = match parts.uri.port() { + None | Some(80) | Some(443) => write!(wrt, "{}", host), + Some(port) => write!(wrt, "{}:{}", host, port), + }; + + match wrt.get_mut().take().freeze().try_into() { + Ok(value) => { + parts.headers.insert(header::HOST, value); + } + Err(e) => self.err = Some(e.into()), + } + } + } + } + + // user agent + self.set_header_if_none( + header::USER_AGENT, + concat!("actix-http/", env!("CARGO_PKG_VERSION")), + ); + } + + let mut request = self.request.take().expect("cannot reuse request builder"); + + // set cookies + if let Some(ref mut jar) = self.cookies { + let mut cookie = String::new(); + for c in jar.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET); + let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + request.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + Ok(request) + } + + /// This method construct new `ClientRequestBuilder` + pub fn take(&mut self) -> ClientRequestBuilder { + ClientRequestBuilder { + request: self.request.take(), + err: self.err.take(), + cookies: self.cookies.take(), + default_headers: self.default_headers, + } + } +} + +#[inline] +fn parts<'a>( + parts: &'a mut Option, err: &Option, +) -> Option<&'a mut ClientRequest> { + if err.is_some() { + return None; + } + parts.as_mut() +} + +impl fmt::Debug for ClientRequestBuilder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(ref parts) = self.request { + writeln!( + f, + "\nClientRequestBuilder {:?} {}:{}", + parts.version, parts.method, parts.uri + )?; + writeln!(f, " headers:")?; + for (key, val) in parts.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } else { + write!(f, "ClientRequestBuilder(Consumed)") + } + } +} diff --git a/src/client/response.rs b/src/client/response.rs new file mode 100644 index 000000000..627a1c78e --- /dev/null +++ b/src/client/response.rs @@ -0,0 +1,128 @@ +use std::cell::{Cell, Ref, RefCell, RefMut}; +use std::fmt; +use std::rc::Rc; + +use http::{HeaderMap, Method, StatusCode, Version}; + +use extensions::Extensions; +use httpmessage::HttpMessage; +use payload::Payload; +use request::{Message, MessageFlags, MessagePool}; +use uri::Url; + +/// Client Response +pub struct ClientResponse { + pub(crate) inner: Rc, +} + +impl HttpMessage for ClientResponse { + type Stream = Payload; + + fn headers(&self) -> &HeaderMap { + &self.inner.headers + } + + #[inline] + fn payload(&self) -> Payload { + if let Some(payload) = self.inner.payload.borrow_mut().take() { + payload + } else { + Payload::empty() + } + } +} + +impl ClientResponse { + /// Create new Request instance + pub fn new() -> ClientResponse { + ClientResponse::with_pool(MessagePool::pool()) + } + + /// Create new Request instance with pool + pub(crate) fn with_pool(pool: &'static MessagePool) -> ClientResponse { + ClientResponse { + inner: Rc::new(Message { + pool, + method: Method::GET, + status: StatusCode::OK, + url: Url::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + flags: Cell::new(MessageFlags::empty()), + payload: RefCell::new(None), + extensions: RefCell::new(Extensions::new()), + }), + } + } + + #[inline] + pub(crate) fn inner(&self) -> &Message { + self.inner.as_ref() + } + + #[inline] + pub(crate) fn inner_mut(&mut self) -> &mut Message { + Rc::get_mut(&mut self.inner).expect("Multiple copies exist") + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.inner().version + } + + /// Get the status from the server. + #[inline] + pub fn status(&self) -> StatusCode { + self.inner().status + } + + #[inline] + /// Returns Request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.inner().headers + } + + #[inline] + /// Returns mutable Request's headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.inner_mut().headers + } + + /// Checks if a connection should be kept alive. + #[inline] + pub fn keep_alive(&self) -> bool { + self.inner().flags.get().contains(MessageFlags::KEEPALIVE) + } + + /// Request extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.inner().extensions.borrow() + } + + /// Mutable reference to a the request's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.inner().extensions.borrow_mut() + } +} + +impl Drop for ClientResponse { + fn drop(&mut self) { + if Rc::strong_count(&self.inner) == 1 { + self.inner.pool.release(self.inner.clone()); + } + } +} + +impl fmt::Debug for ClientResponse { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} diff --git a/src/h1/client.rs b/src/h1/client.rs new file mode 100644 index 000000000..8ec1f0aea --- /dev/null +++ b/src/h1/client.rs @@ -0,0 +1,217 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::io::{self, Write}; + +use bytes::{BufMut, Bytes, BytesMut}; +use tokio_codec::{Decoder, Encoder}; + +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, ResponseDecoder}; +use super::encoder::{RequestEncoder, ResponseLength}; +use super::{Message, MessageType}; +use body::{Binary, Body}; +use client::{ClientRequest, ClientResponse}; +use config::ServiceConfig; +use error::ParseError; +use helpers; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::{Method, Version}; +use request::MessagePool; + +bitflags! { + struct Flags: u8 { + const HEAD = 0b0000_0001; + const UPGRADE = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const KEEPALIVE_ENABLED = 0b0000_1000; + const UNHANDLED = 0b0001_0000; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +/// HTTP/1 Codec +pub struct ClientCodec { + config: ServiceConfig, + decoder: ResponseDecoder, + payload: Option, + version: Version, + + // encoder part + flags: Flags, + headers_size: u32, + te: RequestEncoder, +} + +impl Default for ClientCodec { + fn default() -> Self { + ClientCodec::new(ServiceConfig::default()) + } +} + +impl ClientCodec { + /// Create HTTP/1 codec. + /// + /// `keepalive_enabled` how response `connection` header get generated. + pub fn new(config: ServiceConfig) -> Self { + ClientCodec::with_pool(MessagePool::pool(), config) + } + + /// Create HTTP/1 codec with request's pool + pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + ClientCodec { + config, + decoder: ResponseDecoder::with_pool(pool), + payload: None, + version: Version::HTTP_11, + + flags, + headers_size: 0, + te: RequestEncoder::default(), + } + } + + /// Check if request is upgrade + pub fn upgrade(&self) -> bool { + self.flags.contains(Flags::UPGRADE) + } + + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.flags.contains(Flags::KEEPALIVE) + } + + /// Check last request's message type + pub fn message_type(&self) -> MessageType { + if self.flags.contains(Flags::UNHANDLED) { + MessageType::Unhandled + } else if self.payload.is_none() { + MessageType::None + } else { + MessageType::Payload + } + } + + /// prepare transfer encoding + pub fn prepare_te(&mut self, res: &mut ClientRequest) { + self.te + .update(res, self.flags.contains(Flags::HEAD), self.version); + } + + fn encode_response( + &mut self, msg: ClientRequest, buffer: &mut BytesMut, + ) -> io::Result<()> { + // Connection upgrade + if msg.upgrade() { + self.flags.insert(Flags::UPGRADE); + } + + // render message + { + // status line + writeln!( + Writer(buffer), + "{} {} {:?}\r", + msg.method(), + msg.uri() + .path_and_query() + .map(|u| u.as_str()) + .unwrap_or("/"), + msg.version() + ).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + // write headers + buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE); + for (key, value) in msg.headers() { + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + buffer.reserve(k.len() + v.len() + 4); + buffer.put_slice(k); + buffer.put_slice(b": "); + buffer.put_slice(v); + buffer.put_slice(b"\r\n"); + } + + // set date header + if !msg.headers().contains_key(DATE) { + self.config.set_date(buffer); + } else { + buffer.extend_from_slice(b"\r\n"); + } + } + + Ok(()) + } +} + +impl Decoder for ClientCodec { + type Item = Message; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if self.payload.is_some() { + Ok(match self.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), + Some(PayloadItem::Eof) => Some(Message::Chunk(None)), + None => None, + }) + } else if self.flags.contains(Flags::UNHANDLED) { + Ok(None) + } else if let Some((req, payload)) = self.decoder.decode(src)? { + self.flags + .set(Flags::HEAD, req.inner.method == Method::HEAD); + self.version = req.inner.version; + if self.flags.contains(Flags::KEEPALIVE_ENABLED) { + self.flags.set(Flags::KEEPALIVE, req.keep_alive()); + } + match payload { + PayloadType::None => self.payload = None, + PayloadType::Payload(pl) => self.payload = Some(pl), + PayloadType::Unhandled => { + self.payload = None; + self.flags.insert(Flags::UNHANDLED); + } + }; + Ok(Some(Message::Item(req))) + } else { + Ok(None) + } + } +} + +impl Encoder for ClientCodec { + type Item = Message; + type Error = io::Error; + + fn encode( + &mut self, item: Self::Item, dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + Message::Item(res) => { + self.encode_response(res, dst)?; + } + Message::Chunk(Some(bytes)) => { + self.te.encode(bytes.as_ref(), dst)?; + } + Message::Chunk(None) => { + self.te.encode_eof(dst)?; + } + } + Ok(()) + } +} + +pub struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/src/h1/codec.rs b/src/h1/codec.rs index 020466482..8f7e77e05 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -4,15 +4,16 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType}; +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, RequestDecoder}; use super::encoder::{ResponseEncoder, ResponseLength}; +use super::{Message, MessageType}; use body::{Binary, Body}; use config::ServiceConfig; use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::{Method, Version}; -use request::{Request, RequestPool}; +use request::{MessagePool, Request}; use response::Response; bitflags! { @@ -27,32 +28,6 @@ bitflags! { const AVERAGE_HEADER_SIZE: usize = 30; -#[derive(Debug)] -/// Http response -pub enum OutMessage { - /// Http response message - Response(Response), - /// Payload chunk - Chunk(Option), -} - -/// Incoming http/1 request -#[derive(Debug)] -pub enum InMessage { - /// Request - Message(Request, InMessageType), - /// Payload chunk - Chunk(Option), -} - -/// Incoming request type -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum InMessageType { - None, - Payload, - Unhandled, -} - /// HTTP/1 Codec pub struct Codec { config: ServiceConfig, @@ -71,11 +46,11 @@ impl Codec { /// /// `keepalive_enabled` how response `connection` header get generated. pub fn new(config: ServiceConfig) -> Self { - Codec::with_pool(RequestPool::pool(), config) + Codec::with_pool(MessagePool::pool(), config) } /// Create HTTP/1 codec with request's pool - pub(crate) fn with_pool(pool: &'static RequestPool, config: ServiceConfig) -> Self { + pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self { let flags = if config.keep_alive_enabled() { Flags::KEEPALIVE_ENABLED } else { @@ -104,13 +79,13 @@ impl Codec { } /// Check last request's message type - pub fn message_type(&self) -> InMessageType { + pub fn message_type(&self) -> MessageType { if self.flags.contains(Flags::UNHANDLED) { - InMessageType::Unhandled + MessageType::Unhandled } else if self.payload.is_none() { - InMessageType::None + MessageType::None } else { - InMessageType::Payload + MessageType::Payload } } @@ -256,14 +231,14 @@ impl Codec { } impl Decoder for Codec { - type Item = InMessage; + type Item = Message; type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { if self.payload.is_some() { Ok(match self.payload.as_mut().unwrap().decode(src)? { - Some(PayloadItem::Chunk(chunk)) => Some(InMessage::Chunk(Some(chunk))), - Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)), + Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), + Some(PayloadItem::Eof) => Some(Message::Chunk(None)), None => None, }) } else if self.flags.contains(Flags::UNHANDLED) { @@ -275,22 +250,15 @@ impl Decoder for Codec { if self.flags.contains(Flags::KEEPALIVE_ENABLED) { self.flags.set(Flags::KEEPALIVE, req.keep_alive()); } - let payload = match payload { - RequestPayloadType::None => { - self.payload = None; - InMessageType::None - } - RequestPayloadType::Payload(pl) => { - self.payload = Some(pl); - InMessageType::Payload - } - RequestPayloadType::Unhandled => { + match payload { + PayloadType::None => self.payload = None, + PayloadType::Payload(pl) => self.payload = Some(pl), + PayloadType::Unhandled => { self.payload = None; self.flags.insert(Flags::UNHANDLED); - InMessageType::Unhandled } - }; - Ok(Some(InMessage::Message(req, payload))) + } + Ok(Some(Message::Item(req))) } else { Ok(None) } @@ -298,20 +266,20 @@ impl Decoder for Codec { } impl Encoder for Codec { - type Item = OutMessage; + type Item = Message; type Error = io::Error; fn encode( &mut self, item: Self::Item, dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - OutMessage::Response(res) => { + Message::Item(res) => { self.encode_response(res, dst)?; } - OutMessage::Chunk(Some(bytes)) => { + Message::Chunk(Some(bytes)) => { self.te.encode(bytes.as_ref(), dst)?; } - OutMessage::Chunk(None) => { + Message::Chunk(None) => { self.te.encode_eof(dst)?; } } diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index c2a8d0e99..2c1e6c1f2 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -5,38 +5,43 @@ use futures::{Async, Poll}; use httparse; use tokio_codec::Decoder; +use client::ClientResponse; use error::ParseError; use http::header::{HeaderName, HeaderValue}; -use http::{header, HttpTryFrom, Method, Uri, Version}; -use request::{MessageFlags, Request, RequestPool}; +use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; +use request::{MessageFlags, MessagePool, Request}; use uri::Url; const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; -pub struct RequestDecoder(&'static RequestPool); +/// Client request decoder +pub struct RequestDecoder(&'static MessagePool); + +/// Server response decoder +pub struct ResponseDecoder(&'static MessagePool); /// Incoming request type -pub enum RequestPayloadType { +pub enum PayloadType { None, Payload(PayloadDecoder), Unhandled, } impl RequestDecoder { - pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder { + pub(crate) fn with_pool(pool: &'static MessagePool) -> RequestDecoder { RequestDecoder(pool) } } impl Default for RequestDecoder { fn default() -> RequestDecoder { - RequestDecoder::with_pool(RequestPool::pool()) + RequestDecoder::with_pool(MessagePool::pool()) } } impl Decoder for RequestDecoder { - type Item = (Request, RequestPayloadType); + type Item = (Request, PayloadType); type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -77,7 +82,7 @@ impl Decoder for RequestDecoder { let slice = src.split_to(len).freeze(); // convert headers - let mut msg = RequestPool::get(self.0); + let mut msg = MessagePool::get_request(self.0); { let inner = msg.inner_mut(); inner @@ -156,18 +161,165 @@ impl Decoder for RequestDecoder { // https://tools.ietf.org/html/rfc7230#section-3.3.3 let decoder = if chunked { // Chunked encoding - RequestPayloadType::Payload(PayloadDecoder::chunked()) + PayloadType::Payload(PayloadDecoder::chunked()) } else if let Some(len) = content_length { // Content-Length - RequestPayloadType::Payload(PayloadDecoder::length(len)) + PayloadType::Payload(PayloadDecoder::length(len)) } else if has_upgrade || msg.inner.method == Method::CONNECT { // upgrade(websocket) or connect - RequestPayloadType::Unhandled + PayloadType::Unhandled } else if src.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ParseError::TooLarge); } else { - RequestPayloadType::None + PayloadType::None + }; + + Ok(Some((msg, decoder))) + } +} + +impl ResponseDecoder { + pub(crate) fn with_pool(pool: &'static MessagePool) -> ResponseDecoder { + ResponseDecoder(pool) + } +} + +impl Default for ResponseDecoder { + fn default() -> ResponseDecoder { + ResponseDecoder::with_pool(MessagePool::pool()) + } +} + +impl Decoder for ResponseDecoder { + type Item = (ClientResponse, PayloadType); + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + // Parse http message + let mut chunked = false; + let mut content_length = None; + + let msg = { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + + let (len, version, status, headers_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + + let mut res = httparse::Response::new(&mut parsed); + match res.parse(src)? { + httparse::Status::Complete(len) => { + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + let status = StatusCode::from_u16(res.code.unwrap()) + .map_err(|_| ParseError::Status)?; + HeaderIndex::record(src, res.headers, &mut headers); + + (len, version, status, res.headers.len()) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let slice = src.split_to(len).freeze(); + + // convert headers + let mut msg = MessagePool::get_response(self.0); + { + let inner = msg.inner_mut(); + inner + .flags + .get_mut() + .set(MessageFlags::KEEPALIVE, version != Version::HTTP_10); + + for idx in headers[..headers_len].iter() { + if let Ok(name) = + HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]) + { + // Unsafe: httparse check header value for valid utf-8 + let value = unsafe { + HeaderValue::from_shared_unchecked( + slice.slice(idx.value.0, idx.value.1), + ) + }; + match name { + header::CONTENT_LENGTH => { + if let Ok(s) = value.to_str() { + if let Ok(len) = s.parse::() { + content_length = Some(len); + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(ParseError::Header); + } + } else { + debug!("illegal Content-Length: {:?}", len); + return Err(ParseError::Header); + } + } + // transfer-encoding + header::TRANSFER_ENCODING => { + if let Ok(s) = value.to_str() { + chunked = s.to_lowercase().contains("chunked"); + } else { + return Err(ParseError::Header); + } + } + // connection keep-alive state + header::CONNECTION => { + let ka = if let Ok(conn) = value.to_str() { + if version == Version::HTTP_10 + && conn.contains("keep-alive") + { + true + } else { + version == Version::HTTP_11 && !(conn + .contains("close") + || conn.contains("upgrade")) + } + } else { + false + }; + inner.flags.get_mut().set(MessageFlags::KEEPALIVE, ka); + } + _ => (), + } + + inner.headers.append(name, value); + } else { + return Err(ParseError::Header); + } + } + + inner.status = status; + inner.version = version; + } + msg + }; + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + let decoder = if chunked { + // Chunked encoding + PayloadType::Payload(PayloadDecoder::chunked()) + } else if let Some(len) = content_length { + // Content-Length + PayloadType::Payload(PayloadDecoder::length(len)) + } else if msg.inner.status == StatusCode::SWITCHING_PROTOCOLS + || msg.inner.method == Method::CONNECT + { + // switching protocol or connect + PayloadType::Unhandled + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None }; Ok(Some((msg, decoder))) @@ -488,36 +640,30 @@ mod tests { use super::*; use error::ParseError; - use h1::{InMessage, InMessageType}; + use h1::Message; use httpmessage::HttpMessage; use request::Request; - impl RequestPayloadType { + impl PayloadType { fn unwrap(self) -> PayloadDecoder { match self { - RequestPayloadType::Payload(pl) => pl, + PayloadType::Payload(pl) => pl, _ => panic!(), } } fn is_unhandled(&self) -> bool { match self { - RequestPayloadType::Unhandled => true, + PayloadType::Unhandled => true, _ => false, } } } - impl InMessage { + impl Message { fn message(self) -> Request { match self { - InMessage::Message(req, _) => req, - _ => panic!("error"), - } - } - fn is_payload(&self) -> bool { - match *self { - InMessage::Message(_, payload) => payload == InMessageType::Payload, + Message::Item(req) => req, _ => panic!("error"), } } diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 8ae2ae8ce..a7f97e8c9 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -18,8 +18,8 @@ use error::DispatchError; use request::Request; use response::Response; -use super::codec::{Codec, InMessage, InMessageType, OutMessage}; -use super::H1ServiceResult; +use super::codec::Codec; +use super::{H1ServiceResult, Message, MessageType}; const MAX_PIPELINED_MESSAGES: usize = 16; @@ -48,14 +48,14 @@ where state: State, payload: Option, - messages: VecDeque, + messages: VecDeque, unhandled: Option, ka_expire: Instant, ka_timer: Option, } -enum Message { +enum DispatcherMessage { Item(Request), Error(Response), } @@ -63,8 +63,8 @@ enum Message { enum State { None, ServiceCall(S::Future), - SendResponse(Option<(OutMessage, Body)>), - SendPayload(Option, Option), + SendResponse(Option<(Message, Body)>), + SendPayload(Option, Option>), } impl State { @@ -176,11 +176,12 @@ where State::None => loop { break if let Some(msg) = self.messages.pop_front() { match msg { - Message::Item(req) => Some(self.handle_request(req)?), - Message::Error(res) => Some(State::SendResponse(Some(( - OutMessage::Response(res), - Body::Empty, - )))), + DispatcherMessage::Item(req) => { + Some(self.handle_request(req)?) + } + DispatcherMessage::Error(res) => Some(State::SendResponse( + Some((Message::Item(res), Body::Empty)), + )), } } else { None @@ -196,10 +197,7 @@ where .get_codec_mut() .prepare_te(&mut res); let body = res.replace_body(Body::Empty); - Some(State::SendResponse(Some(( - OutMessage::Response(res), - body, - )))) + Some(State::SendResponse(Some((Message::Item(res), body)))) } Async::NotReady => None, } @@ -216,9 +214,9 @@ where self.flags.remove(Flags::FLUSHED); match body { Body::Empty => Some(State::None), - Body::Binary(bin) => Some(State::SendPayload( + Body::Binary(mut bin) => Some(State::SendPayload( None, - Some(OutMessage::Chunk(bin.into())), + Some(Message::Chunk(Some(bin.take()))), )), Body::Streaming(stream) => { Some(State::SendPayload(Some(stream), None)) @@ -257,7 +255,7 @@ where .framed .as_mut() .unwrap() - .start_send(OutMessage::Chunk(Some(item.into()))) + .start_send(Message::Chunk(Some(item.into()))) { Ok(AsyncSink::Ready) => { self.flags.remove(Flags::FLUSHED); @@ -271,7 +269,7 @@ where }, Ok(Async::Ready(None)) => Some(State::SendPayload( None, - Some(OutMessage::Chunk(None)), + Some(Message::Chunk(None)), )), Ok(Async::NotReady) => return Ok(()), // Err(err) => return Err(DispatchError::Io(err)), @@ -312,7 +310,7 @@ where .get_codec_mut() .prepare_te(&mut res); let body = res.replace_body(Body::Empty); - Ok(State::SendResponse(Some((OutMessage::Response(res), body)))) + Ok(State::SendResponse(Some((Message::Item(res), body)))) } Async::NotReady => Ok(State::ServiceCall(task)), } @@ -333,14 +331,20 @@ where self.flags.insert(Flags::STARTED); match msg { - InMessage::Message(req, payload) => { - match payload { - InMessageType::Payload => { + Message::Item(req) => { + match self + .framed + .as_ref() + .unwrap() + .get_codec() + .message_type() + { + MessageType::Payload => { let (ps, pl) = Payload::new(false); *req.inner.payload.borrow_mut() = Some(pl); self.payload = Some(ps); } - InMessageType::Unhandled => { + MessageType::Unhandled => { self.unhandled = Some(req); return Ok(updated); } @@ -351,10 +355,10 @@ where if self.state.is_empty() { self.state = self.handle_request(req)?; } else { - self.messages.push_back(Message::Item(req)); + self.messages.push_back(DispatcherMessage::Item(req)); } } - InMessage::Chunk(Some(chunk)) => { + Message::Chunk(Some(chunk)) => { if let Some(ref mut payload) = self.payload { payload.feed_data(chunk); } else { @@ -362,19 +366,19 @@ where "Internal server error: unexpected payload chunk" ); self.flags.insert(Flags::DISCONNECTED); - self.messages.push_back(Message::Error( + self.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish(), )); self.error = Some(DispatchError::InternalError); } } - InMessage::Chunk(None) => { + Message::Chunk(None) => { if let Some(mut payload) = self.payload.take() { payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); self.flags.insert(Flags::DISCONNECTED); - self.messages.push_back(Message::Error( + self.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish(), )); self.error = Some(DispatchError::InternalError); @@ -398,8 +402,9 @@ where } // Malformed requests should be responded with 400 - self.messages - .push_back(Message::Error(Response::BadRequest().finish())); + self.messages.push_back(DispatcherMessage::Error( + Response::BadRequest().finish(), + )); self.flags.insert(Flags::DISCONNECTED); self.error = Some(e.into()); break; @@ -443,9 +448,7 @@ where trace!("Slow request timeout"); self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); self.state = State::SendResponse(Some(( - OutMessage::Response( - Response::RequestTimeout().finish(), - ), + Message::Item(Response::RequestTimeout().finish()), Body::Empty, ))); } diff --git a/src/h1/encoder.rs b/src/h1/encoder.rs index ea11f11fd..63ba05d04 100644 --- a/src/h1/encoder.rs +++ b/src/h1/encoder.rs @@ -9,6 +9,7 @@ use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; use http::{StatusCode, Version}; use body::{Binary, Body}; +use client::ClientRequest; use header::ContentEncoding; use http::Method; use request::Request; @@ -165,6 +166,39 @@ impl ResponseEncoder { } } +#[derive(Debug)] +pub(crate) struct RequestEncoder { + head: bool, + pub length: ResponseLength, + pub te: TransferEncoding, +} + +impl Default for RequestEncoder { + fn default() -> Self { + RequestEncoder { + head: false, + length: ResponseLength::None, + te: TransferEncoding::empty(), + } + } +} + +impl RequestEncoder { + /// Encode message + pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + self.te.encode(msg, buf) + } + + /// Encode eof + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + self.te.encode_eof(buf) + } + + pub fn update(&mut self, resp: &mut ClientRequest, head: bool, version: Version) { + self.head = head; + } +} + /// Encoders to handle different Transfer-Encodings. #[derive(Debug)] pub(crate) struct TransferEncoding { diff --git a/src/h1/mod.rs b/src/h1/mod.rs index 1ac65912e..c33b193f5 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -1,13 +1,16 @@ //! HTTP/1 implementation use actix_net::codec::Framed; +use bytes::Bytes; +mod client; mod codec; mod decoder; mod dispatcher; mod encoder; mod service; -pub use self::codec::{Codec, InMessage, InMessageType, OutMessage}; +pub use self::client::ClientCodec; +pub use self::codec::Codec; pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::dispatcher::Dispatcher; pub use self::service::{H1Service, H1ServiceHandler, OneRequest}; @@ -20,3 +23,26 @@ pub enum H1ServiceResult { Shutdown(T), Unhandled(Request, Framed), } + +#[derive(Debug)] +/// Codec message +pub enum Message { + /// Http message + Item(T), + /// Payload chunk + Chunk(Option), +} + +impl From for Message { + fn from(item: T) -> Self { + Message::Item(item) + } +} + +/// Incoming request type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageType { + None, + Payload, + Unhandled, +} diff --git a/src/h1/service.rs b/src/h1/service.rs index 096cb3016..d691ae75d 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -13,9 +13,9 @@ use error::{DispatchError, ParseError}; use request::Request; use response::Response; -use super::codec::{Codec, InMessage}; +use super::codec::Codec; use super::dispatcher::Dispatcher; -use super::H1ServiceResult; +use super::{H1ServiceResult, Message}; /// `NewService` implementation for HTTP1 transport pub struct H1Service { @@ -344,10 +344,10 @@ where fn poll(&mut self) -> Poll { match self.framed.as_mut().unwrap().poll()? { Async::Ready(Some(req)) => match req { - InMessage::Message(req, _) => { + Message::Item(req) => { Ok(Async::Ready((req, self.framed.take().unwrap()))) } - InMessage::Chunk(_) => unreachable!("Something is wrong"), + Message::Chunk(_) => unreachable!("Something is wrong"), }, Async::Ready(None) => Err(ParseError::Incomplete), Async::NotReady => Ok(Async::NotReady), diff --git a/src/lib.rs b/src/lib.rs index c9cebb47d..572c23ae1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,6 +85,7 @@ extern crate http as modhttp; extern crate httparse; extern crate mime; extern crate net2; +extern crate percent_encoding; extern crate rand; extern crate serde; extern crate serde_json; @@ -95,12 +96,14 @@ extern crate tokio_current_thread; extern crate tokio_io; extern crate tokio_tcp; extern crate tokio_timer; +extern crate url as urlcrate; #[cfg(test)] #[macro_use] extern crate serde_derive; mod body; +pub mod client; mod config; mod extensions; mod header; diff --git a/src/request.rs b/src/request.rs index ef28e3694..ad0486395 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,8 +3,9 @@ use std::collections::VecDeque; use std::fmt; use std::rc::Rc; -use http::{header, HeaderMap, Method, Uri, Version}; +use http::{header, HeaderMap, Method, StatusCode, Uri, Version}; +use client::ClientResponse; use extensions::Extensions; use httpmessage::HttpMessage; use payload::Payload; @@ -17,23 +18,24 @@ bitflags! { } } -/// Request's context +/// Request pub struct Request { - pub(crate) inner: Rc, + pub(crate) inner: Rc, } -pub(crate) struct InnerRequest { +pub(crate) struct Message { pub(crate) version: Version, + pub(crate) status: StatusCode, pub(crate) method: Method, pub(crate) url: Url, pub(crate) flags: Cell, pub(crate) headers: HeaderMap, pub(crate) extensions: RefCell, pub(crate) payload: RefCell>, - pool: &'static RequestPool, + pub(crate) pool: &'static MessagePool, } -impl InnerRequest { +impl Message { #[inline] /// Reset request instance pub fn reset(&mut self) { @@ -64,15 +66,16 @@ impl HttpMessage for Request { impl Request { /// Create new Request instance pub fn new() -> Request { - Request::with_pool(RequestPool::pool()) + Request::with_pool(MessagePool::pool()) } /// Create new Request instance with pool - pub(crate) fn with_pool(pool: &'static RequestPool) -> Request { + pub(crate) fn with_pool(pool: &'static MessagePool) -> Request { Request { - inner: Rc::new(InnerRequest { + inner: Rc::new(Message { pool, method: Method::GET, + status: StatusCode::OK, url: Url::default(), version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), @@ -84,12 +87,12 @@ impl Request { } #[inline] - pub(crate) fn inner(&self) -> &InnerRequest { + pub(crate) fn inner(&self) -> &Message { self.inner.as_ref() } #[inline] - pub(crate) fn inner_mut(&mut self) -> &mut InnerRequest { + pub(crate) fn inner_mut(&mut self) -> &mut Message { Rc::get_mut(&mut self.inner).expect("Multiple copies exist") } @@ -201,24 +204,24 @@ impl fmt::Debug for Request { } /// Request's objects pool -pub(crate) struct RequestPool(RefCell>>); +pub(crate) struct MessagePool(RefCell>>); -thread_local!(static POOL: &'static RequestPool = RequestPool::create()); +thread_local!(static POOL: &'static MessagePool = MessagePool::create()); -impl RequestPool { - fn create() -> &'static RequestPool { - let pool = RequestPool(RefCell::new(VecDeque::with_capacity(128))); +impl MessagePool { + fn create() -> &'static MessagePool { + let pool = MessagePool(RefCell::new(VecDeque::with_capacity(128))); Box::leak(Box::new(pool)) } /// Get default request's pool - pub fn pool() -> &'static RequestPool { + pub fn pool() -> &'static MessagePool { POOL.with(|p| *p) } /// Get Request object #[inline] - pub fn get(pool: &'static RequestPool) -> Request { + pub fn get_request(pool: &'static MessagePool) -> Request { if let Some(mut msg) = pool.0.borrow_mut().pop_front() { if let Some(r) = Rc::get_mut(&mut msg) { r.reset(); @@ -228,9 +231,21 @@ impl RequestPool { Request::with_pool(pool) } + /// Get Client Response object + #[inline] + pub fn get_response(pool: &'static MessagePool) -> ClientResponse { + if let Some(mut msg) = pool.0.borrow_mut().pop_front() { + if let Some(r) = Rc::get_mut(&mut msg) { + r.reset(); + } + return ClientResponse { inner: msg }; + } + ClientResponse::with_pool(pool) + } + #[inline] /// Release request instance - pub(crate) fn release(&self, msg: Rc) { + pub(crate) fn release(&self, msg: Rc) { let v = &mut self.0.borrow_mut(); if v.len() < 128 { v.push_front(msg); diff --git a/src/service.rs b/src/service.rs index 4467087bc..a28529145 100644 --- a/src/service.rs +++ b/src/service.rs @@ -8,7 +8,7 @@ use futures::{Async, AsyncSink, Future, Poll, Sink}; use tokio_io::AsyncWrite; use error::ResponseError; -use h1::{Codec, OutMessage}; +use h1::{Codec, Message}; use response::Response; pub struct SendError(PhantomData<(T, R, E)>); @@ -59,7 +59,7 @@ where Ok(r) => Either::A(ok(r)), Err((e, framed)) => Either::B(SendErrorFut { framed: Some(framed), - res: Some(OutMessage::Response(e.error_response())), + res: Some(Message::Item(e.error_response())), err: Some(e), _t: PhantomData, }), @@ -68,7 +68,7 @@ where } pub struct SendErrorFut { - res: Option, + res: Option>, framed: Option>, err: Option, _t: PhantomData, @@ -149,14 +149,14 @@ where fn call(&mut self, (res, framed): Self::Request) -> Self::Future { SendResponseFut { - res: Some(OutMessage::Response(res)), + res: Some(Message::Item(res)), framed: Some(framed), } } } pub struct SendResponseFut { - res: Option, + res: Option>, framed: Option>, } diff --git a/src/ws/client/connect.rs b/src/ws/client/connect.rs new file mode 100644 index 000000000..575b4e4d8 --- /dev/null +++ b/src/ws/client/connect.rs @@ -0,0 +1,95 @@ +//! Http client request +use std::str; + +use cookie::Cookie; +use http::header::{HeaderName, HeaderValue}; +use http::{Error as HttpError, HttpTryFrom}; + +use client::{ClientRequest, ClientRequestBuilder}; +use header::IntoHeaderValue; + +use super::ClientError; + +/// `WebSocket` connection +pub struct Connect { + pub(super) request: ClientRequestBuilder, + pub(super) err: Option, + pub(super) http_err: Option, + pub(super) origin: Option, + pub(super) protocols: Option, + pub(super) max_size: usize, + pub(super) server_mode: bool, +} + +impl Connect { + /// Create new websocket connection + pub fn new>(uri: S) -> Connect { + let mut cl = Connect { + request: ClientRequest::build(), + err: None, + http_err: None, + origin: None, + protocols: None, + max_size: 65_536, + server_mode: false, + }; + cl.request.uri(uri.as_ref()); + cl + } + + /// Set supported websocket protocols + pub fn protocols(mut self, protos: U) -> Self + where + U: IntoIterator + 'static, + V: AsRef, + { + let mut protos = protos + .into_iter() + .fold(String::new(), |acc, s| acc + s.as_ref() + ","); + protos.pop(); + self.protocols = Some(protos); + self + } + + /// Set cookie for handshake request + pub fn cookie(mut self, cookie: Cookie) -> Self { + self.request.cookie(cookie); + self + } + + /// Set request Origin + pub fn origin(mut self, origin: V) -> Self + where + HeaderValue: HttpTryFrom, + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.http_err = Some(e.into()), + } + self + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Disable payload masking. By default ws client masks frame payload. + pub fn server_mode(mut self) -> Self { + self.server_mode = true; + self + } + + /// Set request header + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + self.request.header(key, value); + self + } +} diff --git a/src/ws/client/error.rs b/src/ws/client/error.rs new file mode 100644 index 000000000..62a8800aa --- /dev/null +++ b/src/ws/client/error.rs @@ -0,0 +1,87 @@ +//! Http client request +use std::io; + +use actix_net::connector::ConnectorError; +use http::header::HeaderValue; +use http::StatusCode; + +use error::ParseError; +use http::Error as HttpError; +use ws::ProtocolError; + +/// Websocket client error +#[derive(Fail, Debug)] +pub enum ClientError { + /// Invalid url + #[fail(display = "Invalid url")] + InvalidUrl, + /// Invalid response status + #[fail(display = "Invalid response status")] + InvalidResponseStatus(StatusCode), + /// Invalid upgrade header + #[fail(display = "Invalid upgrade header")] + InvalidUpgradeHeader, + /// Invalid connection header + #[fail(display = "Invalid connection header")] + InvalidConnectionHeader(HeaderValue), + /// Missing CONNECTION header + #[fail(display = "Missing CONNECTION header")] + MissingConnectionHeader, + /// Missing SEC-WEBSOCKET-ACCEPT header + #[fail(display = "Missing SEC-WEBSOCKET-ACCEPT header")] + MissingWebSocketAcceptHeader, + /// Invalid challenge response + #[fail(display = "Invalid challenge response")] + InvalidChallengeResponse(String, HeaderValue), + /// Http parsing error + #[fail(display = "Http parsing error")] + Http(HttpError), + // /// Url parsing error + // #[fail(display = "Url parsing error")] + // Url(UrlParseError), + /// Response parsing error + #[fail(display = "Response parsing error")] + ParseError(ParseError), + /// Protocol error + #[fail(display = "{}", _0)] + Protocol(#[cause] ProtocolError), + /// Connect error + #[fail(display = "{:?}", _0)] + Connect(ConnectorError), + /// IO Error + #[fail(display = "{}", _0)] + Io(io::Error), + /// "Disconnected" + #[fail(display = "Disconnected")] + Disconnected, +} + +impl From for ClientError { + fn from(err: HttpError) -> ClientError { + ClientError::Http(err) + } +} + +impl From for ClientError { + fn from(err: ConnectorError) -> ClientError { + ClientError::Connect(err) + } +} + +impl From for ClientError { + fn from(err: ProtocolError) -> ClientError { + ClientError::Protocol(err) + } +} + +impl From for ClientError { + fn from(err: io::Error) -> ClientError { + ClientError::Io(err) + } +} + +impl From for ClientError { + fn from(err: ParseError) -> ClientError { + ClientError::ParseError(err) + } +} diff --git a/src/ws/client/mod.rs b/src/ws/client/mod.rs new file mode 100644 index 000000000..0dbf081c6 --- /dev/null +++ b/src/ws/client/mod.rs @@ -0,0 +1,48 @@ +mod connect; +mod error; +mod service; + +pub use self::connect::Connect; +pub use self::error::ClientError; +pub use self::service::Client; + +#[derive(PartialEq, Hash, Debug, Clone, Copy)] +pub(crate) enum Protocol { + Http, + Https, + Ws, + Wss, +} + +impl Protocol { + fn from(s: &str) -> Option { + match s { + "http" => Some(Protocol::Http), + "https" => Some(Protocol::Https), + "ws" => Some(Protocol::Ws), + "wss" => Some(Protocol::Wss), + _ => None, + } + } + + fn is_http(self) -> bool { + match self { + Protocol::Https | Protocol::Http => true, + _ => false, + } + } + + fn is_secure(self) -> bool { + match self { + Protocol::Https | Protocol::Wss => true, + _ => false, + } + } + + fn port(self) -> u16 { + match self { + Protocol::Http | Protocol::Ws => 80, + Protocol::Https | Protocol::Wss => 443, + } + } +} diff --git a/src/ws/client/service.rs b/src/ws/client/service.rs new file mode 100644 index 000000000..ab0e48035 --- /dev/null +++ b/src/ws/client/service.rs @@ -0,0 +1,270 @@ +//! websockets client +use std::marker::PhantomData; + +use actix_net::codec::Framed; +use actix_net::connector::{ConnectorError, DefaultConnector}; +use actix_net::service::Service; +use base64; +use futures::future::{err, Either, FutureResult}; +use futures::{Async, Future, Poll, Sink, Stream}; +use http::header::{self, HeaderValue}; +use http::{HttpTryFrom, StatusCode}; +use rand; +use sha1::Sha1; +use tokio_io::{AsyncRead, AsyncWrite}; + +use client::ClientResponse; +use h1; +use ws::Codec; + +use super::{ClientError, Connect, Protocol}; + +/// WebSocket's client +pub struct Client +where + T: Service, + T::Response: AsyncRead + AsyncWrite, +{ + connector: T, +} + +impl Client +where + T: Service, + T::Response: AsyncRead + AsyncWrite, +{ + /// Create new websocket's client factory + pub fn new(connector: T) -> Self { + Client { connector } + } +} + +impl Default for Client> { + fn default() -> Self { + Client::new(DefaultConnector::default()) + } +} + +impl Clone for Client +where + T: Service + Clone, + T::Response: AsyncRead + AsyncWrite, +{ + fn clone(&self) -> Self { + Client { + connector: self.connector.clone(), + } + } +} + +impl Service for Client +where + T: Service, + T::Response: AsyncRead + AsyncWrite + 'static, + T::Future: 'static, +{ + type Request = Connect; + type Response = Framed; + type Error = ClientError; + type Future = Either< + FutureResult, + ClientResponseFut, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.connector.poll_ready().map_err(ClientError::from) + } + + fn call(&mut self, mut req: Self::Request) -> Self::Future { + if let Some(e) = req.err.take() { + Either::A(err(e)) + } else if let Some(e) = req.http_err.take() { + Either::A(err(e.into())) + } else { + // origin + if let Some(origin) = req.origin.take() { + req.request.set_header(header::ORIGIN, origin); + } + + req.request.upgrade(); + req.request.set_header(header::UPGRADE, "websocket"); + req.request.set_header(header::CONNECTION, "upgrade"); + req.request.set_header(header::SEC_WEBSOCKET_VERSION, "13"); + + if let Some(protocols) = req.protocols.take() { + req.request + .set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str()); + } + let mut request = match req.request.finish() { + Ok(req) => req, + Err(e) => return Either::A(err(e.into())), + }; + + if request.uri().host().is_none() { + return Either::A(err(ClientError::InvalidUrl)); + } + + // supported protocols + let proto = if let Some(scheme) = request.uri().scheme_part() { + match Protocol::from(scheme.as_str()) { + Some(proto) => proto, + None => return Either::A(err(ClientError::InvalidUrl)), + } + } else { + return Either::A(err(ClientError::InvalidUrl)); + }; + + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let sec_key: [u8; 16] = rand::random(); + let key = base64::encode(&sec_key); + + request.headers_mut().insert( + header::SEC_WEBSOCKET_KEY, + HeaderValue::try_from(key.as_str()).unwrap(), + ); + + // prep connection + let host = { + let uri = request.uri(); + format!( + "{}:{}", + uri.host().unwrap(), + uri.port().unwrap_or_else(|| proto.port()) + ) + }; + + let fut = Box::new( + self.connector + .call(host) + .map_err(|e| ClientError::from(e)) + .and_then(move |io| { + // h1 protocol + let framed = Framed::new(io, h1::ClientCodec::default()); + framed + .send(request.into()) + .map_err(|e| ClientError::from(e)) + .and_then(|framed| { + framed + .into_future() + .map_err(|(e, _)| ClientError::from(e)) + }) + }), + ); + + // start handshake + Either::B(ClientResponseFut { + key, + fut, + max_size: req.max_size, + server_mode: req.server_mode, + _t: PhantomData, + }) + } + } +} + +/// Future that implementes client websocket handshake process. +/// +/// It resolves to a `Framed` instance. +pub struct ClientResponseFut +where + T: AsyncRead + AsyncWrite, +{ + fut: Box< + Future< + Item = ( + Option>, + Framed, + ), + Error = ClientError, + >, + >, + key: String, + max_size: usize, + server_mode: bool, + _t: PhantomData, +} + +impl Future for ClientResponseFut +where + T: AsyncRead + AsyncWrite, +{ + type Item = Framed; + type Error = ClientError; + + fn poll(&mut self) -> Poll { + let (item, framed) = try_ready!(self.fut.poll()); + + let res = match item { + Some(h1::Message::Item(res)) => res, + Some(h1::Message::Chunk(_)) => unreachable!(), + None => return Err(ClientError::Disconnected), + }; + + // verify response + if res.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(ClientError::InvalidResponseStatus(res.status())); + } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = res.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + trace!("Invalid upgrade header"); + return Err(ClientError::InvalidUpgradeHeader); + } + // Check for "CONNECTION" header + if let Some(conn) = res.headers().get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_lowercase().contains("upgrade") { + trace!("Invalid connection header: {}", s); + return Err(ClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + trace!("Invalid connection header: {:?}", conn); + return Err(ClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + trace!("Missing connection header"); + return Err(ClientError::MissingConnectionHeader); + } + + if let Some(key) = res.headers().get(header::SEC_WEBSOCKET_ACCEPT) { + // field is constructed by concatenating /key/ + // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(self.key.as_ref()); + sha1.update(WS_GUID); + let encoded = base64::encode(&sha1.digest().bytes()); + if key.as_bytes() != encoded.as_bytes() { + trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); + } + } else { + trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(ClientError::MissingWebSocketAcceptHeader); + }; + + // websockets codec + let codec = if self.server_mode { + Codec::new().max_size(self.max_size) + } else { + Codec::new().max_size(self.max_size).client_mode() + }; + + Ok(Async::Ready(framed.into_framed(codec))) + } +} diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 690c56feb..b5a667084 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -10,6 +10,7 @@ use http::{header, Method, StatusCode}; use request::Request; use response::{ConnectionType, Response, ResponseBuilder}; +mod client; mod codec; mod frame; mod mask; @@ -17,6 +18,7 @@ mod proto; mod service; mod transport; +pub use self::client::{Client, ClientError, Connect}; pub use self::codec::{Codec, Frame, Message}; pub use self::frame::Parser; pub use self::proto::{CloseCode, CloseReason, OpCode}; diff --git a/tests/test_ws.rs b/tests/test_ws.rs index a5503c209..85770fa8e 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -11,12 +11,12 @@ use actix::System; use actix_net::codec::Framed; use actix_net::framed::IntoFramed; use actix_net::server::Server; -use actix_net::service::NewServiceExt; +use actix_net::service::{NewServiceExt, Service}; use actix_net::stream::TakeItem; use actix_web::{test, ws as web_ws}; -use bytes::Bytes; -use futures::future::{ok, Either}; -use futures::{Future, Sink, Stream}; +use bytes::{Bytes, BytesMut}; +use futures::future::{lazy, ok, Either}; +use futures::{Future, IntoFuture, Sink, Stream}; use actix_http::{h1, ws, ResponseError, ServiceConfig}; @@ -51,14 +51,14 @@ fn test_simple() { .and_then(TakeItem::new().map_err(|_| ())) .and_then(|(req, framed): (_, Framed<_, _>)| { // validate request - if let Some(h1::InMessage::Message(req, _)) = req { + if let Some(h1::Message::Item(req)) = req { match ws::verify_handshake(&req) { Err(e) => { // validation failed let resp = e.error_response(); Either::A( framed - .send(h1::OutMessage::Response(resp)) + .send(h1::Message::Item(resp)) .map_err(|_| ()) .map(|_| ()), ) @@ -66,7 +66,7 @@ fn test_simple() { Ok(_) => Either::B( // send response framed - .send(h1::OutMessage::Response( + .send(h1::Message::Item( ws::handshake_response(&req).finish(), )).map_err(|_| ()) .and_then(|framed| { @@ -116,4 +116,43 @@ fn test_simple() { ))) ); } + + // client service + let mut client = sys + .block_on(lazy(|| Ok::<_, ()>(ws::Client::default()).into_future())) + .unwrap(); + let framed = sys + .block_on(client.call(ws::Connect::new(format!("http://{}/", addr)))) + .unwrap(); + + let framed = sys + .block_on(framed.send(ws::Message::Text("text".to_string()))) + .unwrap(); + let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + + let framed = sys + .block_on(framed.send(ws::Message::Binary("text".into()))) + .unwrap(); + let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + ); + + let framed = sys + .block_on(framed.send(ws::Message::Ping("text".into()))) + .unwrap(); + let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + + let framed = sys + .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + .unwrap(); + + let (item, _framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + ) }