diff --git a/src/client/mod.rs b/src/client/mod.rs index dad0c3f3..de9bfdcc 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,9 @@ mod parser; mod request; mod response; +mod writer; +pub(crate) use self::writer::HttpClientWriter; pub use self::request::{ClientRequest, ClientRequestBuilder}; pub use self::response::ClientResponse; pub use self::parser::{HttpResponseParser, HttpResponseParserError}; diff --git a/src/client/parser.rs b/src/client/parser.rs index 73ef2278..9e81692e 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -13,6 +13,7 @@ use server::h1::{Decoder, chunked}; use server::encoding::PayloadType; use super::ClientResponse; +use super::response::ClientMessage; const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; @@ -225,10 +226,16 @@ impl HttpResponseParser { decoder: decoder, }; Ok(Async::Ready( - (ClientResponse::new(status, version, hdrs, Some(payload)), Some(info)))) + (ClientResponse::new( + ClientMessage{ + status: status, version: version, + headers: hdrs, cookies: None, payload: Some(payload)}), Some(info)))) } else { Ok(Async::Ready( - (ClientResponse::new(status, version, hdrs, None), None))) + (ClientResponse::new( + ClientMessage{ + status: status, version: version, + headers: hdrs, cookies: None, payload: None}), None))) } } } diff --git a/src/client/request.rs b/src/client/request.rs index c441947f..99abad8e 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -12,7 +12,7 @@ use body::Body; use error::Error; use headers::ContentEncoding; - +/// An HTTP Client Request pub struct ClientRequest { uri: Uri, method: Method, @@ -38,6 +38,44 @@ impl Default for ClientRequest { } } +impl ClientRequest { + + /// Create request builder for `GET` request + pub fn get(uri: U) -> ClientRequestBuilder where Uri: HttpTryFrom { + let mut builder = ClientRequest::build(); + builder.method(Method::GET).uri(uri); + builder + } + + /// Create request builder for `HEAD` request + pub fn head(uri: U) -> ClientRequestBuilder where Uri: HttpTryFrom { + let mut builder = ClientRequest::build(); + builder.method(Method::HEAD).uri(uri); + builder + } + + /// Create request builder for `POST` request + pub fn post(uri: U) -> ClientRequestBuilder where Uri: HttpTryFrom { + let mut builder = ClientRequest::build(); + builder.method(Method::POST).uri(uri); + builder + } + + /// Create request builder for `PUT` request + pub fn put(uri: U) -> ClientRequestBuilder where Uri: HttpTryFrom { + let mut builder = ClientRequest::build(); + builder.method(Method::PUT).uri(uri); + builder + } + + /// Create request builder for `DELETE` request + pub fn delete(uri: U) -> ClientRequestBuilder where Uri: HttpTryFrom { + let mut builder = ClientRequest::build(); + builder.method(Method::DELETE).uri(uri); + builder + } +} + impl ClientRequest { /// Create client request builder @@ -73,6 +111,18 @@ impl ClientRequest { self.method = method } + /// Get http version for the request + #[inline] + pub fn version(&self) -> Version { + self.version + } + + /// Set http `Version` for the request + #[inline] + pub fn set_version(&mut self, version: Version) { + self.version = version + } + /// Get the headers from the request #[inline] pub fn headers(&self) -> &HeaderMap { @@ -120,6 +170,10 @@ impl fmt::Debug for ClientRequest { } +/// An HTTP client request builder +/// +/// This type can be used to construct an instance of `ClientRequest` through a +/// builder-like pattern. pub struct ClientRequestBuilder { request: Option, err: Option, @@ -165,7 +219,10 @@ impl ClientRequestBuilder { self } - /// Set a header. + /// Add a header. + /// + /// Header get appended to existing header. + /// To override header use `set_header()` method. /// /// ```rust /// # extern crate http; @@ -266,9 +323,10 @@ impl ClientRequestBuilder { /// # use actix_web::httpcodes::*; /// # /// use actix_web::headers::Cookie; + /// use actix_web::client::ClientRequest; /// - /// fn index(req: HttpRequest) -> Result { - /// Ok(HTTPOk.build() + /// fn main() { + /// let req = ClientRequest::build() /// .cookie( /// Cookie::build("name", "value") /// .domain("www.rust-lang.org") @@ -276,9 +334,8 @@ impl ClientRequestBuilder { /// .secure(true) /// .http_only(true) /// .finish()) - /// .finish()?) + /// .finish().unwrap(); /// } - /// fn main() {} /// ``` pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { if self.cookies.is_none() { diff --git a/src/client/response.rs b/src/client/response.rs index 0bd8ed60..820d0100 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,70 +1,241 @@ -#![allow(dead_code)] -use std::fmt; +use std::{fmt, str}; +use std::rc::Rc; +use std::cell::UnsafeCell; + +use bytes::{Bytes, BytesMut}; +use cookie::Cookie; +use futures::{Async, Future, Poll, Stream}; +use http_range::HttpRange; use http::{HeaderMap, StatusCode, Version}; -use http::header::HeaderValue; +use http::header::{self, HeaderValue}; +use mime::Mime; +use serde_json; +use serde::de::DeserializeOwned; -use payload::Payload; +use payload::{Payload, ReadAny}; +use multipart::Multipart; +use httprequest::UrlEncoded; +use error::{CookieParseError, ParseError, PayloadError, JsonPayloadError, HttpRangeError}; -pub struct ClientResponse { - /// The response's status - status: StatusCode, - - /// The response's version - version: Version, - - /// The response's headers - headers: HeaderMap, - - payload: Option, +pub(crate) struct ClientMessage { + pub status: StatusCode, + pub version: Version, + pub headers: HeaderMap, + pub cookies: Option>>, + pub payload: Option, } -impl ClientResponse { - pub fn new(status: StatusCode, version: Version, - headers: HeaderMap, payload: Option) -> Self { - ClientResponse { - status: status, version: version, headers: headers, payload: payload +impl Default for ClientMessage { + + fn default() -> ClientMessage { + ClientMessage { + status: StatusCode::OK, + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + cookies: None, + payload: None, } } +} - /// Get the HTTP version of this response. +pub struct ClientResponse(Rc>); + +impl ClientResponse { + + pub(crate) fn new(msg: ClientMessage) -> ClientResponse { + ClientResponse(Rc::new(UnsafeCell::new(msg))) + } + + #[inline] + fn as_ref(&self) -> &ClientMessage { + unsafe{ &*self.0.get() } + } + + #[inline] + #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))] + fn as_mut(&self) -> &mut ClientMessage { + unsafe{ &mut *self.0.get() } + } + + /// Get the HTTP version of this response. #[inline] pub fn version(&self) -> Version { - self.version + self.as_ref().version } /// Get the headers from the response. #[inline] pub fn headers(&self) -> &HeaderMap { - &self.headers + &self.as_ref().headers } /// Get a mutable reference to the headers. #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers + &mut self.as_mut().headers } /// Get the status from the server. #[inline] pub fn status(&self) -> StatusCode { - self.status + self.as_ref().status } /// Set the `StatusCode` for this response. #[inline] - pub fn status_mut(&mut self) -> &mut StatusCode { - &mut self.status + pub fn set_status(&mut self, status: StatusCode) { + self.as_mut().status = status + } + + /// Load request cookies. + pub fn cookies(&self) -> Result<&Vec>, CookieParseError> { + if self.as_ref().cookies.is_none() { + let msg = self.as_mut(); + let mut cookies = Vec::new(); + if let Some(val) = msg.headers.get(header::COOKIE) { + let s = str::from_utf8(val.as_bytes()) + .map_err(CookieParseError::from)?; + for cookie in s.split("; ") { + cookies.push(Cookie::parse_encoded(cookie)?.into_owned()); + } + } + msg.cookies = Some(cookies) + } + Ok(self.as_ref().cookies.as_ref().unwrap()) + } + + /// Return request cookie. + pub fn cookie(&self, name: &str) -> Option<&Cookie> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies { + if cookie.name() == name { + return Some(cookie) + } + } + } + None + } + + /// Read the request content type. If request does not contain + /// *Content-Type* header, empty str get returned. + pub fn content_type(&self) -> &str { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return content_type.split(';').next().unwrap().trim() + } + } + "" + } + + /// Convert the request content type to a known mime type. + pub fn mime_type(&self) -> Option { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return match content_type.parse() { + Ok(mt) => Some(mt), + Err(_) => None + }; + } + } + None + } + + /// Check if request has chunked transfer encoding + pub fn chunked(&self) -> Result { + if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + Ok(s.to_lowercase().contains("chunked")) + } else { + Err(ParseError::Header) + } + } else { + Ok(false) + } + } + + /// Parses Range HTTP header string as per RFC 2616. + /// `size` is full size of response (file). + pub fn range(&self, size: u64) -> Result, HttpRangeError> { + if let Some(range) = self.headers().get(header::RANGE) { + HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) + .map_err(|e| e.into()) + } else { + Ok(Vec::new()) + } + } + + /// Returns reference to the associated http payload. + #[inline] + pub fn payload(&self) -> &Payload { + let msg = self.as_mut(); + if msg.payload.is_none() { + msg.payload = Some(Payload::empty()); + } + msg.payload.as_ref().unwrap() + } + + /// Returns mutable reference to the associated http payload. + #[inline] + pub fn payload_mut(&mut self) -> &mut Payload { + let msg = self.as_mut(); + if msg.payload.is_none() { + msg.payload = Some(Payload::empty()); + } + msg.payload.as_mut().unwrap() + } + + /// Load request body. + /// + /// By default only 256Kb payload reads to a memory, then `ResponseBody` + /// resolves to an error. Use `RequestBody::limit()` + /// method to change upper limit. + pub fn body(&self) -> ResponseBody { + ResponseBody::from_response(self) + } + + + /// Return stream to http payload processes as multipart. + /// + /// Content-type: multipart/form-data; + pub fn multipart(&mut self) -> Multipart { + Multipart::from_response(self) + } + + /// Parse `application/x-www-form-urlencoded` encoded body. + /// Return `UrlEncoded` future. It resolves to a `HashMap` which + /// contains decoded parameters. + /// + /// Returns error: + /// + /// * content type is not `application/x-www-form-urlencoded` + /// * transfer encoding is `chunked`. + /// * content-length is greater than 256k + pub fn urlencoded(&self) -> UrlEncoded { + UrlEncoded::from(self.payload().clone(), + self.headers(), + self.chunked().unwrap_or(false)) + } + + /// Parse `application/json` encoded body. + /// Return `JsonResponse` future. It resolves to a `T` value. + /// + /// Returns error: + /// + /// * content type is not `application/json` + /// * content length is greater than 256k + pub fn json(&self) -> JsonResponse { + JsonResponse::from_response(self) } } impl fmt::Debug for ClientResponse { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = write!( - f, "\nClientResponse {:?} {}\n", self.version, self.status); + f, "\nClientResponse {:?} {}\n", self.version(), self.status()); let _ = write!(f, " headers:\n"); - for key in self.headers.keys() { - let vals: Vec<_> = self.headers.get_all(key).iter().collect(); + for key in self.headers().keys() { + let vals: Vec<_> = self.headers().get_all(key).iter().collect(); if vals.len() > 1 { let _ = write!(f, " {:?}: {:?}\n", key, vals); } else { @@ -74,3 +245,160 @@ impl fmt::Debug for ClientResponse { res } } + +impl Clone for ClientResponse { + fn clone(&self) -> ClientResponse { + ClientResponse(self.0.clone()) + } +} + +/// Future that resolves to a complete request body. +pub struct ResponseBody { + pl: ReadAny, + body: BytesMut, + limit: usize, + resp: Option, +} + +impl ResponseBody { + + /// Create `RequestBody` for request. + pub fn from_response(resp: &ClientResponse) -> ResponseBody { + let pl = resp.payload().readany(); + ResponseBody { + pl: pl, + body: BytesMut::new(), + limit: 262_144, + resp: Some(resp.clone()), + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for ResponseBody { + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll { + if let Some(resp) = self.resp.take() { + if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + if len > 262_144 { + return Err(PayloadError::Overflow); + } + } else { + return Err(PayloadError::UnknownLength); + } + } else { + return Err(PayloadError::UnknownLength); + } + } + } + + loop { + return match self.pl.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(None)) => { + Ok(Async::Ready(self.body.take().freeze())) + }, + Ok(Async::Ready(Some(chunk))) => { + if (self.body.len() + chunk.len()) > self.limit { + Err(PayloadError::Overflow) + } else { + self.body.extend_from_slice(&chunk); + continue + } + }, + Err(err) => Err(err), + } + } + } +} + +/// Client response payload json parser that resolves to a deserialized `T` value. +/// +/// Returns error: +/// +/// * content type is not `application/json` +/// * content length is greater than 256k +pub struct JsonResponse{ + limit: usize, + ct: &'static str, + resp: Option, + fut: Option>>, +} + +impl JsonResponse { + + /// Create `JsonBody` for request. + pub fn from_response(resp: &ClientResponse) -> Self { + JsonResponse{ + limit: 262_144, + resp: Some(resp.clone()), + fut: None, + ct: "application/json", + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set allowed content type. + /// + /// By default *application/json* content type is used. Set content type + /// to empty string if you want to disable content type check. + pub fn content_type(mut self, ct: &'static str) -> Self { + self.ct = ct; + self + } +} + +impl Future for JsonResponse { + type Item = T; + type Error = JsonPayloadError; + + fn poll(&mut self) -> Poll { + if let Some(resp) = self.resp.take() { + if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + if len > self.limit { + return Err(JsonPayloadError::Overflow); + } + } else { + return Err(JsonPayloadError::Overflow); + } + } + } + // check content-type + if !self.ct.is_empty() && resp.content_type() != self.ct { + return Err(JsonPayloadError::ContentType) + } + + let limit = self.limit; + let fut = resp.payload().readany() + .from_err() + .fold(BytesMut::new(), move |mut body, chunk| { + if (body.len() + chunk.len()) > limit { + Err(JsonPayloadError::Overflow) + } else { + body.extend_from_slice(&chunk); + Ok(body) + } + }) + .and_then(|body| Ok(serde_json::from_slice::(&body)?)); + self.fut = Some(Box::new(fut)); + } + + self.fut.as_mut().expect("JsonResponse could not be used second time").poll() + } +} diff --git a/src/ws/writer.rs b/src/client/writer.rs similarity index 90% rename from src/ws/writer.rs rename to src/client/writer.rs index 0cee0181..f77d4c98 100644 --- a/src/ws/writer.rs +++ b/src/client/writer.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] use std::io; +use std::fmt::Write; use bytes::BufMut; use futures::{Async, Poll}; use tokio_io::AsyncWrite; @@ -23,17 +24,17 @@ bitflags! { } } -pub(crate) struct Writer { +pub(crate) struct HttpClientWriter { flags: Flags, written: u64, headers_size: u32, buffer: SharedBytes, } -impl Writer { +impl HttpClientWriter { - pub fn new(buf: SharedBytes) -> Writer { - Writer { + pub fn new(buf: SharedBytes) -> HttpClientWriter { + HttpClientWriter { flags: Flags::empty(), written: 0, headers_size: 0, @@ -73,7 +74,7 @@ impl Writer { } } -impl Writer { +impl HttpClientWriter { pub fn start(&mut self, msg: &mut ClientRequest) { // prepare task @@ -85,11 +86,8 @@ impl Writer { buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE); // status line - // helpers::write_status_line(version, msg.status().as_u16(), &mut buffer); - // buffer.extend_from_slice(msg.reason().as_bytes()); - buffer.extend_from_slice(b"GET "); - buffer.extend_from_slice(msg.uri().path().as_ref()); - buffer.extend_from_slice(b" HTTP/1.1\r\n"); + let _ = write!(buffer, "{} {} {:?}\r\n", + msg.method(), msg.uri().path(), msg.version()); // write headers for (key, value) in msg.headers() { diff --git a/src/httprequest.rs b/src/httprequest.rs index e1689219..e22dcb7e 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -541,7 +541,9 @@ impl HttpRequest { /// # fn main() {} /// ``` pub fn urlencoded(&self) -> UrlEncoded { - UrlEncoded::from_request(self) + UrlEncoded::from(self.payload().clone(), + self.headers(), + self.chunked().unwrap_or(false)) } /// Parse `application/json` encoded body. @@ -624,16 +626,16 @@ pub struct UrlEncoded { } impl UrlEncoded { - pub fn from_request(req: &HttpRequest) -> UrlEncoded { + pub fn from(pl: Payload, headers: &HeaderMap, chunked: bool) -> UrlEncoded { let mut encoded = UrlEncoded { - pl: req.payload().clone(), + pl: pl, body: BytesMut::new(), error: None }; - if let Ok(true) = req.chunked() { + if chunked { encoded.error = Some(UrlencodedError::Chunked); - } else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { + } else if let Some(len) = headers.get(header::CONTENT_LENGTH) { if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { if len > 262_144 { @@ -649,7 +651,7 @@ impl UrlEncoded { // check content type if encoded.error.is_none() { - if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { + if let Some(content_type) = headers.get(header::CONTENT_TYPE) { if let Ok(content_type) = content_type.to_str() { if content_type.to_lowercase() == "application/x-www-form-urlencoded" { return encoded diff --git a/src/lib.rs b/src/lib.rs index 02e2b1b1..4722f8c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,8 @@ #![cfg_attr(actix_nightly, feature( specialization, // for impl ErrorResponse for std::error::Error ))] +#![cfg_attr(feature = "cargo-clippy", allow( + decimal_literal_representation,))] #[macro_use] extern crate log; diff --git a/src/multipart.rs b/src/multipart.rs index aa3f7201..00f2676c 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -14,6 +14,7 @@ use futures::task::{Task, current as current_task}; use error::{ParseError, PayloadError, MultipartError}; use payload::Payload; +use client::ClientResponse; use httprequest::HttpRequest; const MAX_HEADERS: usize = 32; @@ -97,6 +98,19 @@ impl Multipart { } } + /// Create multipart instance for client response. + pub fn from_response(resp: &mut ClientResponse) -> Multipart { + match Multipart::boundary(resp.headers()) { + Ok(boundary) => Multipart::new(boundary, resp.payload().clone()), + Err(err) => + Multipart { + error: Some(err), + safety: Safety::new(), + inner: None, + } + } + } + /// Extract boundary info from headers. pub fn boundary(headers: &HeaderMap) -> Result { if let Some(content_type) = headers.get(header::CONTENT_TYPE) { diff --git a/src/ws/client.rs b/src/ws/client.rs index 2e6094e6..f9536cea 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -21,12 +21,11 @@ use server::shared::SharedBytes; use server::{utils, IoStream}; use client::{ClientRequest, ClientRequestBuilder, - HttpResponseParser, HttpResponseParserError}; + HttpResponseParser, HttpResponseParserError, HttpClientWriter}; use super::Message; use super::proto::{CloseCode, OpCode}; use super::frame::Frame; -use super::writer::Writer; use super::connect::{TcpConnector, TcpConnectorError}; /// Websockt client error @@ -86,7 +85,7 @@ impl From for WsClientError { } } -type WsFuture = Future, WsWriter), Error=WsClientError>; +pub type WsFuture = Future, WsWriter), Error=WsClientError>; /// Websockt client pub struct WsClient { @@ -190,7 +189,7 @@ impl WsClient { struct WsInner { stream: T, - writer: Writer, + writer: HttpClientWriter, parser: HttpResponseParser, parser_buf: BytesMut, closed: bool, @@ -218,7 +217,7 @@ impl WsHandshake { let inner = WsInner { stream: stream, - writer: Writer::new(SharedBytes::default()), + writer: HttpClientWriter::new(SharedBytes::default()), parser: HttpResponseParser::default(), parser_buf: BytesMut::new(), closed: false, diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 6ac87ecd..1cb59a04 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -59,18 +59,15 @@ mod frame; mod proto; mod context; mod mask; +mod client; mod connect; -mod writer; -pub mod client; - -use ws::frame::Frame; -use ws::proto::{hash_key, OpCode}; -pub use ws::proto::CloseCode; -pub use ws::context::WebsocketContext; - -pub use self::client::{WsClient, WsClientError, WsReader, WsWriter}; +use self::frame::Frame; +use self::proto::{hash_key, OpCode}; +pub use self::proto::CloseCode; +pub use self::context::WebsocketContext; +pub use self::client::{WsClient, WsClientError, WsReader, WsWriter, WsFuture}; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";