1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 07:53:00 +01:00

add ClientRequest and ClientRequestBuilder

This commit is contained in:
Nikolay Kim 2018-01-29 11:39:26 -08:00
parent b6a394a113
commit 6416a796c3
7 changed files with 435 additions and 140 deletions

View File

@ -1,5 +1,7 @@
mod parser;
mod request;
mod response;
pub use self::request::{ClientRequest, ClientRequestBuilder};
pub use self::response::ClientResponse;
pub use self::parser::{HttpResponseParser, HttpResponseParserError};

View File

@ -17,7 +17,7 @@ use super::ClientResponse;
const MAX_BUFFER_SIZE: usize = 131_072;
const MAX_HEADERS: usize = 96;
#[derive(Default)]
pub struct HttpResponseParser {
payload: Option<PayloadInfo>,
}
@ -41,11 +41,6 @@ pub enum HttpResponseParserError {
}
impl HttpResponseParser {
pub fn new() -> HttpResponseParser {
HttpResponseParser {
payload: None,
}
}
fn decode(&mut self, buf: &mut BytesMut) -> Result<Decoding, HttpResponseParserError> {
if let Some(ref mut payload) = self.payload {

384
src/client/request.rs Normal file
View File

@ -0,0 +1,384 @@
use std::{fmt, mem};
use std::io::Write;
use cookie::{Cookie, CookieJar};
use bytes::{BytesMut, BufMut};
use http::{HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use serde_json;
use serde::Serialize;
use body::Body;
use error::Error;
use headers::ContentEncoding;
pub struct ClientRequest {
uri: Uri,
method: Method,
version: Version,
headers: HeaderMap,
body: Body,
chunked: Option<bool>,
encoding: ContentEncoding,
}
impl Default for ClientRequest {
fn default() -> ClientRequest {
ClientRequest {
uri: Uri::default(),
method: Method::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
body: Body::Empty,
chunked: None,
encoding: ContentEncoding::Auto,
}
}
}
impl ClientRequest {
/// Create client request builder
pub fn build() -> ClientRequestBuilder {
ClientRequestBuilder {
request: Some(ClientRequest::default()),
err: None,
cookies: None,
}
}
/// Get the request uri
#[inline]
pub fn uri(&self) -> &Uri {
&self.uri
}
/// Set client request uri
#[inline]
pub fn set_uri(&mut self, uri: Uri) {
self.uri = uri
}
/// Get the request method
#[inline]
pub fn method(&self) -> &Method {
&self.method
}
/// Set http `Method` for the request
#[inline]
pub fn set_method(&mut self, method: Method) {
self.method = method
}
/// Get the headers from the request
#[inline]
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
/// Get a mutable reference to the headers
#[inline]
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
/// Get body os this response
#[inline]
pub fn body(&self) -> &Body {
&self.body
}
/// Set a body
pub fn set_body<B: Into<Body>>(&mut self, body: B) {
self.body = body.into();
}
/// Set a body and return previous body value
pub fn replace_body<B: Into<Body>>(&mut self, body: B) -> Body {
mem::replace(&mut self.body, body.into())
}
}
impl fmt::Debug for ClientRequest {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = write!(f, "\nClientRequest {:?} {}:{}\n",
self.version, self.method, self.uri);
let _ = write!(f, " headers:\n");
for key in self.headers.keys() {
let vals: Vec<_> = self.headers.get_all(key).iter().collect();
if vals.len() > 1 {
let _ = write!(f, " {:?}: {:?}\n", key, vals);
} else {
let _ = write!(f, " {:?}: {:?}\n", key, vals[0]);
}
}
res
}
}
pub struct ClientRequestBuilder {
request: Option<ClientRequest>,
err: Option<HttpError>,
cookies: Option<CookieJar>,
}
impl ClientRequestBuilder {
/// Set HTTP uri of request.
#[inline]
pub fn uri<U>(&mut self, uri: U) -> &mut Self where Uri: HttpTryFrom<U> {
match Uri::try_from(uri) {
Ok(uri) => {
// set request host header
if let Some(host) = uri.host() {
self.set_header(header::HOST, host);
}
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.uri = uri;
}
},
Err(e) => self.err = Some(e.into(),),
}
self
}
/// Set HTTP method of this request.
#[inline]
pub fn method(&mut self, method: Method) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.method = method;
}
self
}
/// Set HTTP version of this request.
///
/// By default requests's http version depends on network stream
#[inline]
pub fn version(&mut self, version: Version) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.version = version;
}
self
}
/// Set a header.
///
/// ```rust
/// # extern crate http;
/// # extern crate actix_web;
/// # use actix_web::client::*;
/// #
/// use http::header;
///
/// fn main() {
/// let req = ClientRequest::build()
/// .header("X-TEST", "value")
/// .header(header::CONTENT_TYPE, "application/json")
/// .finish().unwrap();
/// }
/// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, HeaderValue: HttpTryFrom<V>
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.append(key, value); }
Err(e) => self.err = Some(e.into()),
}
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Replace a header.
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, HeaderValue: HttpTryFrom<V>
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(key, value); }
Err(e) => self.err = Some(e.into()),
}
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set content encoding.
///
/// By default `ContentEncoding::Identity` is used.
#[inline]
pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.encoding = enc;
}
self
}
/// Enables automatic chunked transfer encoding
#[inline]
pub fn chunked(&mut self) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.chunked = Some(true);
}
self
}
/// Set request's content type
#[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self
where HeaderValue: HttpTryFrom<V>
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); },
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set content length
#[inline]
pub fn content_length(&mut self, len: u64) -> &mut Self {
let mut wrt = BytesMut::new().writer();
let _ = write!(wrt, "{}", len);
self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
}
/// Set a cookie
///
/// ```rust
/// # extern crate actix_web;
/// # use actix_web::*;
/// # use actix_web::httpcodes::*;
/// #
/// use actix_web::headers::Cookie;
///
/// fn index(req: HttpRequest) -> Result<HttpResponse> {
/// Ok(HTTPOk.build()
/// .cookie(
/// Cookie::build("name", "value")
/// .domain("www.rust-lang.org")
/// .path("/")
/// .secure(true)
/// .http_only(true)
/// .finish())
/// .finish()?)
/// }
/// fn main() {}
/// ```
pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self {
if self.cookies.is_none() {
let mut jar = CookieJar::new();
jar.add(cookie.into_owned());
self.cookies = Some(jar)
} else {
self.cookies.as_mut().unwrap().add(cookie.into_owned());
}
self
}
/// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()` method.
pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self {
{
if self.cookies.is_none() {
self.cookies = Some(CookieJar::new())
}
let jar = self.cookies.as_mut().unwrap();
let cookie = cookie.clone().into_owned();
jar.add_original(cookie.clone());
jar.remove(cookie);
}
self
}
/// This method calls provided closure with builder reference if value is true.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: FnOnce(&mut ClientRequestBuilder)
{
if value {
f(self);
}
self
}
/// This method calls provided closure with builder reference if value is Some.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where F: FnOnce(T, &mut ClientRequestBuilder)
{
if let Some(val) = value {
f(val, self);
}
self
}
/// Set a body and generate `ClientRequest`.
///
/// `ClientRequestBuilder` can not be used after this call.
pub fn body<B: Into<Body>>(&mut self, body: B) -> Result<ClientRequest, HttpError> {
if let Some(e) = self.err.take() {
return Err(e)
}
let mut request = self.request.take().expect("cannot reuse request builder");
// set cookies
if let Some(ref jar) = self.cookies {
for cookie in jar.delta() {
request.headers.append(
header::SET_COOKIE,
HeaderValue::from_str(&cookie.to_string())?);
}
}
request.body = body.into();
Ok(request)
}
/// Set a json body and generate `ClientRequest`
///
/// `ClientRequestBuilder` can not be used after this call.
pub fn json<T: Serialize>(&mut self, value: T) -> Result<ClientRequest, Error> {
let body = serde_json::to_string(&value)?;
let contains = if let Some(parts) = parts(&mut self.request, &self.err) {
parts.headers.contains_key(header::CONTENT_TYPE)
} else {
true
};
if !contains {
self.header(header::CONTENT_TYPE, "application/json");
}
Ok(self.body(body)?)
}
/// Set an empty body and generate `ClientRequest`
///
/// `ClientRequestBuilder` can not be used after this call.
pub fn finish(&mut self) -> Result<ClientRequest, HttpError> {
self.body(Body::Empty)
}
}
#[inline]
fn parts<'a>(parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>)
-> Option<&'a mut ClientRequest>
{
if err.is_some() {
return None
}
parts.as_mut()
}

View File

@ -83,37 +83,37 @@ impl HttpResponse {
self.get_ref().error.as_ref()
}
/// Get the HTTP version of this response.
/// Get the HTTP version of this response
#[inline]
pub fn version(&self) -> Option<Version> {
self.get_ref().version
}
/// Get the headers from the response.
/// Get the headers from the response
#[inline]
pub fn headers(&self) -> &HeaderMap {
&self.get_ref().headers
}
/// Get a mutable reference to the headers.
/// Get a mutable reference to the headers
#[inline]
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.get_mut().headers
}
/// Get the status from the server.
/// Get the response status code
#[inline]
pub fn status(&self) -> StatusCode {
self.get_ref().status
}
/// Set the `StatusCode` for this response.
/// Set the `StatusCode` for this response
#[inline]
pub fn status_mut(&mut self) -> &mut StatusCode {
&mut self.get_mut().status
}
/// Get custom reason for the response.
/// Get custom reason for the response
#[inline]
pub fn reason(&self) -> &str {
if let Some(reason) = self.get_ref().reason {
@ -123,7 +123,7 @@ impl HttpResponse {
}
}
/// Set the custom reason for the response.
/// Set the custom reason for the response
#[inline]
pub fn set_reason(&mut self, reason: &'static str) -> &mut Self {
self.get_mut().reason = Some(reason);

View File

@ -109,7 +109,8 @@ mod resource;
mod handler;
mod pipeline;
mod client;
#[doc(hidden)]
pub mod client;
pub mod fs;
pub mod ws;

View File

@ -1,28 +1,27 @@
//! Http client request
use std::{fmt, io, str};
use std::{io, str};
use std::rc::Rc;
use std::time::Duration;
use std::cell::UnsafeCell;
use base64;
use rand;
use cookie::{Cookie, CookieJar};
use cookie::Cookie;
use bytes::BytesMut;
use http::{Method, Version, HeaderMap, HttpTryFrom, StatusCode, Error as HttpError};
use http::{HttpTryFrom, StatusCode, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use url::Url;
use sha1::Sha1;
use futures::{Async, Future, Poll, Stream};
// use futures::unsync::oneshot;
use tokio_core::net::TcpStream;
use body::{Body, Binary};
use body::Binary;
use error::UrlParseError;
use headers::ContentEncoding;
use server::shared::SharedBytes;
use server::{utils, IoStream};
use client::{HttpResponseParser, HttpResponseParserError};
use client::{ClientRequest, ClientRequestBuilder,
HttpResponseParser, HttpResponseParserError};
use super::Message;
use super::proto::{CloseCode, OpCode};
@ -91,36 +90,23 @@ type WsFuture<T> = Future<Item=(WsReader<T>, WsWriter<T>), Error=WsClientError>;
/// Websockt client
pub struct WsClient {
request: Option<ClientRequest>,
request: ClientRequestBuilder,
err: Option<WsClientError>,
http_err: Option<HttpError>,
cookies: Option<CookieJar>,
origin: Option<HeaderValue>,
protocols: Option<String>,
}
impl WsClient {
pub fn new<S: AsRef<str>>(url: S) -> WsClient {
pub fn new<S: AsRef<str>>(uri: S) -> WsClient {
let mut cl = WsClient {
request: None,
request: ClientRequest::build(),
err: None,
http_err: None,
cookies: None,
origin: None,
protocols: None };
match Url::parse(url.as_ref()) {
Ok(url) => {
if url.scheme() != "http" && url.scheme() != "https" &&
url.scheme() != "ws" && url.scheme() != "wss" || !url.has_host() {
cl.err = Some(WsClientError::InvalidUrl);
} else {
cl.request = Some(ClientRequest::new(Method::GET, url));
}
},
Err(err) => cl.err = Some(err.into()),
}
cl.request.uri(uri.as_ref());
cl
}
@ -136,13 +122,7 @@ impl WsClient {
}
pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self {
if self.cookies.is_none() {
let mut jar = CookieJar::new();
jar.add(cookie.into_owned());
self.cookies = Some(jar)
} else {
self.cookies.as_mut().unwrap().add(cookie.into_owned());
}
self.request.cookie(cookie);
self
}
@ -158,20 +138,9 @@ impl WsClient {
}
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>,
HeaderValue: HttpTryFrom<V>
where HeaderName: HttpTryFrom<K>, HeaderValue: HttpTryFrom<V>
{
if let Some(parts) = parts(&mut self.request, &self.err, &self.http_err) {
match HeaderName::try_from(key) {
Ok(key) => {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.append(key, value); }
Err(e) => self.http_err = Some(e.into()),
}
},
Err(e) => self.http_err = Some(e.into()),
};
}
self.request.header(key, value);
self
}
@ -182,37 +151,35 @@ impl WsClient {
if let Some(e) = self.http_err.take() {
return Err(e.into())
}
let mut request = self.request.take().expect("cannot reuse request builder");
// headers
if let Some(ref jar) = self.cookies {
for cookie in jar.delta() {
request.headers.append(
header::SET_COOKIE,
HeaderValue::from_str(&cookie.to_string()).map_err(HttpError::from)?);
}
}
// origin
if let Some(origin) = self.origin.take() {
request.headers.insert(header::ORIGIN, origin);
self.request.set_header(header::ORIGIN, origin);
}
request.headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
request.headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
request.headers.insert(
HeaderName::try_from("SEC-WEBSOCKET-VERSION").unwrap(),
HeaderValue::from_static("13"));
self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
if let Some(protocols) = self.protocols.take() {
request.headers.insert(
HeaderName::try_from("SEC-WEBSOCKET-PROTOCOL").unwrap(),
HeaderValue::try_from(protocols.as_str()).unwrap());
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
}
let request = self.request.finish()?;
if request.uri().host().is_none() {
return Err(WsClientError::InvalidUrl)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return Err(WsClientError::InvalidUrl);
}
} else {
return Err(WsClientError::InvalidUrl);
}
let connect = TcpConnector::new(
request.url.host_str().unwrap(),
request.url.port().unwrap_or(80), Duration::from_secs(5));
request.uri().host().unwrap(),
request.uri().port().unwrap_or(80), Duration::from_secs(5));
Ok(Box::new(
connect
@ -221,60 +188,6 @@ impl WsClient {
}
}
#[inline]
fn parts<'a>(parts: &'a mut Option<ClientRequest>,
err: &Option<WsClientError>,
http_err: &Option<HttpError>) -> Option<&'a mut ClientRequest>
{
if err.is_some() || http_err.is_some() {
return None
}
parts.as_mut()
}
pub(crate) struct ClientRequest {
pub url: Url,
pub method: Method,
pub version: Version,
pub headers: HeaderMap,
pub body: Body,
pub chunked: Option<bool>,
pub encoding: ContentEncoding,
}
impl ClientRequest {
#[inline]
fn new(method: Method, url: Url) -> ClientRequest {
ClientRequest {
url: url,
method: method,
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
body: Body::Empty,
chunked: None,
encoding: ContentEncoding::Auto,
}
}
}
impl fmt::Debug for ClientRequest {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = write!(f, "\nClientRequest {:?} {}:{}\n",
self.version, self.method, self.url);
let _ = write!(f, " headers:\n");
for key in self.headers.keys() {
let vals: Vec<_> = self.headers.get_all(key).iter().collect();
if vals.len() > 1 {
let _ = write!(f, " {:?}: {:?}\n", key, vals);
} else {
let _ = write!(f, " {:?}: {:?}\n", key, vals[0]);
}
}
res
}
}
struct WsInner<T> {
stream: T,
writer: Writer,
@ -299,14 +212,14 @@ impl<T: IoStream> WsHandshake<T> {
let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key);
request.headers.insert(
request.headers_mut().insert(
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
HeaderValue::try_from(key.as_str()).unwrap());
let inner = WsInner {
stream: stream,
writer: Writer::new(SharedBytes::default()),
parser: HttpResponseParser::new(),
parser: HttpResponseParser::default(),
parser_buf: BytesMut::new(),
closed: false,
error_sent: false,
@ -370,7 +283,7 @@ impl<T: IoStream> Future for WsHandshake<T> {
{
// ... field is constructed by concatenating /key/ ...
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut sha1 = Sha1::new();
sha1.update(self.key.as_ref());
sha1.update(WS_GUID);

View File

@ -9,7 +9,7 @@ use body::Binary;
use server::{WriterState, MAX_WRITE_BUFFER_SIZE};
use server::shared::SharedBytes;
use super::client::ClientRequest;
use client::ClientRequest;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
@ -82,17 +82,17 @@ impl Writer {
// render message
{
let buffer = self.buffer.get_mut();
buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE);
buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE);
// status line
// helpers::write_status_line(version, msg.status().as_u16(), &mut buffer);
// buffer.extend_from_slice(msg.reason().as_bytes());
buffer.extend_from_slice(b"GET ");
buffer.extend_from_slice(msg.url.path().as_ref());
buffer.extend_from_slice(msg.uri().path().as_ref());
buffer.extend_from_slice(b" HTTP/1.1\r\n");
// write headers
for (key, value) in &msg.headers {
for (key, value) in msg.headers() {
let v = value.as_ref();
let k = key.as_str().as_bytes();
buffer.reserve(k.len() + v.len() + 4);