1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-23 16:21:06 +01:00

add rustfmt config

This commit is contained in:
Nikolay Kim 2018-04-13 16:02:01 -07:00
parent 95f6277007
commit 113f5ad1a8
91 changed files with 8057 additions and 5509 deletions

View File

@ -3,7 +3,6 @@ extern crate version_check;
use std::{env, fs};
#[cfg(unix)]
fn main() {
println!("cargo:rerun-if-env-changed=USE_SKEPTIC");
@ -11,23 +10,23 @@ fn main() {
if env::var("USE_SKEPTIC").is_ok() {
let _ = fs::remove_file(f);
// generates doc tests for `README.md`.
skeptic::generate_doc_tests(
&[// "README.md",
"guide/src/qs_1.md",
"guide/src/qs_2.md",
"guide/src/qs_3.md",
"guide/src/qs_3_5.md",
"guide/src/qs_4.md",
"guide/src/qs_4_5.md",
"guide/src/qs_5.md",
"guide/src/qs_7.md",
"guide/src/qs_8.md",
"guide/src/qs_9.md",
"guide/src/qs_10.md",
"guide/src/qs_12.md",
"guide/src/qs_13.md",
"guide/src/qs_14.md",
]);
skeptic::generate_doc_tests(&[
// "README.md",
"guide/src/qs_1.md",
"guide/src/qs_2.md",
"guide/src/qs_3.md",
"guide/src/qs_3_5.md",
"guide/src/qs_4.md",
"guide/src/qs_4_5.md",
"guide/src/qs_5.md",
"guide/src/qs_7.md",
"guide/src/qs_8.md",
"guide/src/qs_9.md",
"guide/src/qs_10.md",
"guide/src/qs_12.md",
"guide/src/qs_13.md",
"guide/src/qs_14.md",
]);
} else {
let _ = fs::File::create(f);
}

7
rustfmt.toml Normal file
View File

@ -0,0 +1,7 @@
max_width = 89
reorder_imports = true
reorder_imports_in_group = true
reorder_imported_names = true
wrap_comments = true
fn_args_density = "Compressed"
#use_small_heuristics = false

View File

@ -1,24 +1,24 @@
use std::mem;
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::mem;
use std::rc::Rc;
use http::Method;
use handler::Reply;
use router::{Router, Resource};
use resource::{ResourceHandler};
use handler::{FromRequest, Handler, Responder, RouteHandler, WrapHandler};
use header::ContentEncoding;
use handler::{Handler, RouteHandler, WrapHandler, FromRequest, Responder};
use http::Method;
use httprequest::HttpRequest;
use pipeline::{Pipeline, PipelineHandler, HandlerType};
use middleware::Middleware;
use server::{HttpHandler, IntoHttpHandler, HttpHandlerTask, ServerSettings};
use pipeline::{HandlerType, Pipeline, PipelineHandler};
use resource::ResourceHandler;
use router::{Resource, Router};
use server::{HttpHandler, HttpHandlerTask, IntoHttpHandler, ServerSettings};
#[deprecated(since="0.5.0", note="please use `actix_web::App` instead")]
#[deprecated(since = "0.5.0", note = "please use `actix_web::App` instead")]
pub type Application<S> = App<S>;
/// Application
pub struct HttpApplication<S=()> {
pub struct HttpApplication<S = ()> {
state: Rc<S>,
prefix: String,
prefix_len: usize,
@ -36,28 +36,25 @@ pub(crate) struct Inner<S> {
}
impl<S: 'static> PipelineHandler<S> for Inner<S> {
fn encoding(&self) -> ContentEncoding {
self.encoding
}
fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply {
match htype {
HandlerType::Normal(idx) =>
self.resources[idx].handle(req, Some(&mut self.default)),
HandlerType::Handler(idx) =>
self.handlers[idx].1.handle(req),
HandlerType::Default =>
self.default.handle(req, None)
HandlerType::Normal(idx) => {
self.resources[idx].handle(req, Some(&mut self.default))
}
HandlerType::Handler(idx) => self.handlers[idx].1.handle(req),
HandlerType::Default => self.default.handle(req, None),
}
}
}
impl<S: 'static> HttpApplication<S> {
#[inline]
fn as_ref(&self) -> &Inner<S> {
unsafe{&*self.inner.get()}
unsafe { &*self.inner.get() }
}
#[inline]
@ -70,20 +67,21 @@ impl<S: 'static> HttpApplication<S> {
let &(ref prefix, _) = &inner.handlers[idx];
let m = {
let path = &req.path()[inner.prefix..];
path.starts_with(prefix) && (
path.len() == prefix.len() ||
path.split_at(prefix.len()).1.starts_with('/'))
path.starts_with(prefix)
&& (path.len() == prefix.len()
|| path.split_at(prefix.len()).1.starts_with('/'))
};
if m {
let path: &'static str = unsafe {
mem::transmute(&req.path()[inner.prefix+prefix.len()..]) };
mem::transmute(&req.path()[inner.prefix + prefix.len()..])
};
if path.is_empty() {
req.match_info_mut().add("tail", "");
} else {
req.match_info_mut().add("tail", path.split_at(1).1);
}
return HandlerType::Handler(idx)
return HandlerType::Handler(idx);
}
}
HandlerType::Default
@ -93,7 +91,7 @@ impl<S: 'static> HttpApplication<S> {
#[cfg(test)]
pub(crate) fn run(&mut self, mut req: HttpRequest<S>) -> Reply {
let tp = self.get_handler(&mut req);
unsafe{&mut *self.inner.get()}.handle(req, tp)
unsafe { &mut *self.inner.get() }.handle(req, tp)
}
#[cfg(test)]
@ -103,19 +101,23 @@ impl<S: 'static> HttpApplication<S> {
}
impl<S: 'static> HttpHandler for HttpApplication<S> {
fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest> {
let m = {
let path = req.path();
path.starts_with(&self.prefix) && (
path.len() == self.prefix_len ||
path.split_at(self.prefix_len).1.starts_with('/'))
path.starts_with(&self.prefix)
&& (path.len() == self.prefix_len
|| path.split_at(self.prefix_len).1.starts_with('/'))
};
if m {
let mut req = req.with_state(Rc::clone(&self.state), self.router.clone());
let tp = self.get_handler(&mut req);
let inner = Rc::clone(&self.inner);
Ok(Box::new(Pipeline::new(req, Rc::clone(&self.middlewares), inner, tp)))
Ok(Box::new(Pipeline::new(
req,
Rc::clone(&self.middlewares),
inner,
tp,
)))
} else {
Err(req)
}
@ -134,14 +136,14 @@ struct ApplicationParts<S> {
middlewares: Vec<Box<Middleware<S>>>,
}
/// Structure that follows the builder pattern for building application instances.
pub struct App<S=()> {
/// Structure that follows the builder pattern for building application
/// instances.
pub struct App<S = ()> {
parts: Option<ApplicationParts<S>>,
}
impl App<()> {
/// Create application with empty state. Application can
/// Create application with empty state. Application can
/// be configured with a builder-like pattern.
pub fn new() -> App<()> {
App {
@ -155,7 +157,7 @@ impl App<()> {
external: HashMap::new(),
encoding: ContentEncoding::Auto,
middlewares: Vec::new(),
})
}),
}
}
}
@ -166,8 +168,10 @@ impl Default for App<()> {
}
}
impl<S> App<S> where S: 'static {
impl<S> App<S>
where
S: 'static,
{
/// Create application with specified state. Application can be
/// configured with a builder-like pattern.
///
@ -185,7 +189,7 @@ impl<S> App<S> where S: 'static {
external: HashMap::new(),
middlewares: Vec::new(),
encoding: ContentEncoding::Auto,
})
}),
}
}
@ -256,20 +260,22 @@ impl<S> App<S> where S: 'static {
/// }
/// ```
pub fn route<T, F, R>(mut self, path: &str, method: Method, f: F) -> App<S>
where F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
where
F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
{
{
let parts: &mut ApplicationParts<S> = unsafe{
mem::transmute(self.parts.as_mut().expect("Use after finish"))};
let parts: &mut ApplicationParts<S> = unsafe {
mem::transmute(self.parts.as_mut().expect("Use after finish"))
};
// get resource handler
for &mut (ref pattern, ref mut handler) in &mut parts.resources {
if let Some(ref mut handler) = *handler {
if pattern.pattern() == path {
handler.method(method).with(f);
return self
return self;
}
}
}
@ -315,7 +321,8 @@ impl<S> App<S> where S: 'static {
/// }
/// ```
pub fn resource<F, R>(mut self, path: &str, f: F) -> App<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static
where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{
{
let parts = self.parts.as_mut().expect("Use after finish");
@ -334,13 +341,17 @@ impl<S> App<S> where S: 'static {
#[doc(hidden)]
pub fn register_resource(&mut self, path: &str, resource: ResourceHandler<S>) {
let pattern = Resource::new(resource.get_name(), path);
self.parts.as_mut().expect("Use after finish")
.resources.push((pattern, Some(resource)));
self.parts
.as_mut()
.expect("Use after finish")
.resources
.push((pattern, Some(resource)));
}
/// Default resource to be used if no matching route could be found.
pub fn default_resource<F, R>(mut self, f: F) -> App<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static
where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{
{
let parts = self.parts.as_mut().expect("Use after finish");
@ -350,8 +361,7 @@ impl<S> App<S> where S: 'static {
}
/// Set default content encoding. `ContentEncoding::Auto` is set by default.
pub fn default_encoding(mut self, encoding: ContentEncoding) -> App<S>
{
pub fn default_encoding(mut self, encoding: ContentEncoding) -> App<S> {
{
let parts = self.parts.as_mut().expect("Use after finish");
parts.encoding = encoding;
@ -383,7 +393,9 @@ impl<S> App<S> where S: 'static {
/// }
/// ```
pub fn external_resource<T, U>(mut self, name: T, url: U) -> App<S>
where T: AsRef<str>, U: AsRef<str>
where
T: AsRef<str>,
U: AsRef<str>,
{
{
let parts = self.parts.as_mut().expect("Use after finish");
@ -393,7 +405,8 @@ impl<S> App<S> where S: 'static {
}
parts.external.insert(
String::from(name.as_ref()),
Resource::external(name.as_ref(), url.as_ref()));
Resource::external(name.as_ref(), url.as_ref()),
);
}
self
}
@ -419,8 +432,7 @@ impl<S> App<S> where S: 'static {
/// }});
/// }
/// ```
pub fn handler<H: Handler<S>>(mut self, path: &str, handler: H) -> App<S>
{
pub fn handler<H: Handler<S>>(mut self, path: &str, handler: H) -> App<S> {
{
let mut path = path.trim().trim_right_matches('/').to_owned();
if !path.is_empty() && !path.starts_with('/') {
@ -428,15 +440,20 @@ impl<S> App<S> where S: 'static {
}
let parts = self.parts.as_mut().expect("Use after finish");
parts.handlers.push((path, Box::new(WrapHandler::new(handler))));
parts
.handlers
.push((path, Box::new(WrapHandler::new(handler))));
}
self
}
/// Register a middleware.
pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> App<S> {
self.parts.as_mut().expect("Use after finish")
.middlewares.push(Box::new(mw));
self.parts
.as_mut()
.expect("Use after finish")
.middlewares
.push(Box::new(mw));
self
}
@ -468,7 +485,8 @@ impl<S> App<S> where S: 'static {
/// }
/// ```
pub fn configure<F>(self, cfg: F) -> App<S>
where F: Fn(App<S>) -> App<S>
where
F: Fn(App<S>) -> App<S>,
{
cfg(self)
}
@ -490,15 +508,13 @@ impl<S> App<S> where S: 'static {
let (router, resources) = Router::new(&prefix, parts.settings, resources);
let inner = Rc::new(UnsafeCell::new(
Inner {
prefix: prefix_len,
default: parts.default,
encoding: parts.encoding,
handlers: parts.handlers,
resources,
}
));
let inner = Rc::new(UnsafeCell::new(Inner {
prefix: prefix_len,
default: parts.default,
encoding: parts.encoding,
handlers: parts.handlers,
resources,
}));
HttpApplication {
state: Rc::new(parts.state),
@ -582,14 +598,13 @@ impl<S: 'static> Iterator for App<S> {
}
}
#[cfg(test)]
mod tests {
use http::StatusCode;
use super::*;
use test::TestRequest;
use http::StatusCode;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use test::TestRequest;
#[test]
fn test_default_resource() {
@ -603,14 +618,20 @@ mod tests {
let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let mut app = App::new()
.default_resource(|r| r.f(|_| HttpResponse::MethodNotAllowed()))
.finish();
let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::METHOD_NOT_ALLOWED
);
}
#[test]
@ -627,7 +648,8 @@ mod tests {
let mut app = App::with_state(10)
.resource("/", |r| r.f(|_| HttpResponse::Ok()))
.finish();
let req = HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone());
let req =
HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone());
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK);
}
@ -675,11 +697,17 @@ mod tests {
let req = TestRequest::with_uri("/testapp").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
}
#[test]
@ -702,11 +730,17 @@ mod tests {
let req = TestRequest::with_uri("/testapp").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
}
#[test]
@ -730,31 +764,53 @@ mod tests {
let req = TestRequest::with_uri("/prefix/testapp").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/prefix/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
}
#[test]
fn test_route() {
let mut app = App::new()
.route("/test", Method::GET, |_: HttpRequest| HttpResponse::Ok())
.route("/test", Method::POST, |_: HttpRequest| HttpResponse::Created())
.route("/test", Method::GET, |_: HttpRequest| {
HttpResponse::Ok()
})
.route("/test", Method::POST, |_: HttpRequest| {
HttpResponse::Created()
})
.finish();
let req = TestRequest::with_uri("/test").method(Method::GET).finish();
let req = TestRequest::with_uri("/test")
.method(Method::GET)
.finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK);
let req = TestRequest::with_uri("/test").method(Method::POST).finish();
let req = TestRequest::with_uri("/test")
.method(Method::POST)
.finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::CREATED);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::CREATED
);
let req = TestRequest::with_uri("/test").method(Method::HEAD).finish();
let req = TestRequest::with_uri("/test")
.method(Method::HEAD)
.finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
}
#[test]
@ -766,7 +822,10 @@ mod tests {
let req = TestRequest::with_uri("/test").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/app/test").finish();
let resp = app.run(req);
@ -782,12 +841,16 @@ mod tests {
let req = TestRequest::with_uri("/app/testapp").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/app/blah").finish();
let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND);
assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
}
}

View File

@ -1,18 +1,17 @@
use std::{fmt, mem};
use std::rc::Rc;
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use futures::Stream;
use std::rc::Rc;
use std::sync::Arc;
use std::{fmt, mem};
use error::Error;
use context::ActorHttpContext;
use error::Error;
use handler::Responder;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
/// Type represent streaming body
pub type BodyStream = Box<Stream<Item=Bytes, Error=Error>>;
pub type BodyStream = Box<Stream<Item = Bytes, Error = Error>>;
/// Represents various types of http message body.
pub enum Body {
@ -50,7 +49,7 @@ impl Body {
pub fn is_streaming(&self) -> bool {
match *self {
Body::Streaming(_) | Body::Actor(_) => true,
_ => false
_ => false,
}
}
@ -59,7 +58,7 @@ impl Body {
pub fn is_binary(&self) -> bool {
match *self {
Body::Binary(_) => true,
_ => false
_ => false,
}
}
@ -96,7 +95,10 @@ impl fmt::Debug for Body {
}
}
impl<T> From<T> for Body where T: Into<Binary>{
impl<T> From<T> for Body
where
T: Into<Binary>,
{
fn from(b: T) -> Body {
Body::Binary(b.into())
}
@ -257,8 +259,8 @@ impl Responder for Binary {
fn respond_to(self, _: HttpRequest) -> Result<HttpResponse, Error> {
Ok(HttpResponse::Ok()
.content_type("application/octet-stream")
.body(self))
.content_type("application/octet-stream")
.body(self))
}
}
@ -349,7 +351,7 @@ mod tests {
#[test]
fn test_bytes_mut() {
let b = BytesMut::from("test");
let b = BytesMut::from("test");
assert_eq!(Binary::from(b.clone()).len(), 4);
assert_eq!(Binary::from(b).as_ref(), b"test");
}

View File

@ -1,36 +1,35 @@
use std::{fmt, mem, io, time};
use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::net::Shutdown;
use std::time::{Duration, Instant};
use std::collections::{HashMap, VecDeque};
use std::net::Shutdown;
use std::rc::Rc;
use std::time::{Duration, Instant};
use std::{fmt, io, mem, time};
use actix::{fut, Actor, ActorFuture, Arbiter, Context, AsyncContext,
Recipient, Syn, Handler, Message, ActorResponse,
Supervised, ContextFutureSpawner};
use actix::registry::ArbiterService;
use actix::actors::{Connect as ResolveConnect, Connector, ConnectorError};
use actix::fut::WrapFuture;
use actix::actors::{Connector, ConnectorError, Connect as ResolveConnect};
use actix::registry::ArbiterService;
use actix::{fut, Actor, ActorFuture, ActorResponse, Arbiter, AsyncContext, Context,
ContextFutureSpawner, Handler, Message, Recipient, Supervised, Syn};
use http::{Uri, HttpTryFrom, Error as HttpError};
use futures::{Async, Future, Poll};
use futures::task::{Task, current as current_task};
use futures::task::{current as current_task, Task};
use futures::unsync::oneshot;
use tokio_io::{AsyncRead, AsyncWrite};
use futures::{Async, Future, Poll};
use http::{Error as HttpError, HttpTryFrom, Uri};
use tokio_core::reactor::Timeout;
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature="alpn")]
use openssl::ssl::{SslMethod, SslConnector, Error as OpensslError};
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use openssl::ssl::{Error as OpensslError, SslConnector, SslMethod};
#[cfg(feature = "alpn")]
use tokio_openssl::SslConnectorExt;
#[cfg(all(feature="tls", not(feature="alpn")))]
use native_tls::{TlsConnector, Error as TlsError};
#[cfg(all(feature="tls", not(feature="alpn")))]
#[cfg(all(feature = "tls", not(feature = "alpn")))]
use native_tls::{Error as TlsError, TlsConnector};
#[cfg(all(feature = "tls", not(feature = "alpn")))]
use tokio_tls::TlsConnectorExt;
use {HAS_OPENSSL, HAS_TLS};
use server::IoStream;
use {HAS_OPENSSL, HAS_TLS};
/// Client connector usage stats
#[derive(Default, Message)]
@ -54,7 +53,10 @@ pub struct Connect {
impl Connect {
/// Create `Connect` message for specified `Uri`
pub fn new<U>(uri: U) -> Result<Connect, HttpError> where Uri: HttpTryFrom<U> {
pub fn new<U>(uri: U) -> Result<Connect, HttpError>
where
Uri: HttpTryFrom<U>,
{
Ok(Connect {
uri: Uri::try_from(uri).map_err(|e| e.into())?,
wait_timeout: Duration::from_secs(5),
@ -92,13 +94,13 @@ pub struct Pause {
impl Pause {
/// Create message with pause duration parameter
pub fn new(time: Duration) -> Pause {
Pause{time: Some(time)}
Pause { time: Some(time) }
}
}
impl Default for Pause {
fn default() -> Pause {
Pause{time: None}
Pause { time: None }
}
}
@ -114,21 +116,21 @@ pub struct Resume;
#[derive(Fail, Debug)]
pub enum ClientConnectorError {
/// Invalid URL
#[fail(display="Invalid URL")]
#[fail(display = "Invalid URL")]
InvalidUrl,
/// SSL feature is not enabled
#[fail(display="SSL is not supported")]
#[fail(display = "SSL is not supported")]
SslIsNotSupported,
/// SSL error
#[cfg(feature="alpn")]
#[fail(display="{}", _0)]
#[cfg(feature = "alpn")]
#[fail(display = "{}", _0)]
SslError(#[cause] OpensslError),
/// SSL error
#[cfg(all(feature="tls", not(feature="alpn")))]
#[fail(display="{}", _0)]
#[cfg(all(feature = "tls", not(feature = "alpn")))]
#[fail(display = "{}", _0)]
SslError(#[cause] TlsError),
/// Connection error
@ -152,7 +154,7 @@ impl From<ConnectorError> for ClientConnectorError {
fn from(err: ConnectorError) -> ClientConnectorError {
match err {
ConnectorError::Timeout => ClientConnectorError::Timeout,
_ => ClientConnectorError::Connector(err)
_ => ClientConnectorError::Connector(err),
}
}
}
@ -166,9 +168,9 @@ struct Waiter {
/// `ClientConnector` type is responsible for transport layer of a
/// client connection.
pub struct ClientConnector {
#[cfg(all(feature="alpn"))]
#[cfg(all(feature = "alpn"))]
connector: SslConnector,
#[cfg(all(feature="tls", not(feature="alpn")))]
#[cfg(all(feature = "tls", not(feature = "alpn")))]
connector: TlsConnector,
stats: ClientConnectorStats,
@ -207,12 +209,12 @@ impl Default for ClientConnector {
fn default() -> ClientConnector {
let _modified = Rc::new(Cell::new(false));
#[cfg(all(feature="alpn"))]
#[cfg(all(feature = "alpn"))]
{
let builder = SslConnector::builder(SslMethod::tls()).unwrap();
ClientConnector::with_connector(builder.build())
}
#[cfg(all(feature="tls", not(feature="alpn")))]
#[cfg(all(feature = "tls", not(feature = "alpn")))]
{
let builder = TlsConnector::builder().unwrap();
ClientConnector {
@ -235,29 +237,29 @@ impl Default for ClientConnector {
}
}
#[cfg(not(any(feature="alpn", feature="tls")))]
ClientConnector {stats: ClientConnectorStats::default(),
subscriber: None,
pool: Rc::new(Pool::new(Rc::clone(&_modified))),
pool_modified: _modified,
conn_lifetime: Duration::from_secs(15),
conn_keep_alive: Duration::from_secs(75),
limit: 100,
limit_per_host: 0,
acquired: 0,
acquired_per_host: HashMap::new(),
available: HashMap::new(),
to_close: Vec::new(),
waiters: HashMap::new(),
wait_timeout: None,
paused: None,
#[cfg(not(any(feature = "alpn", feature = "tls")))]
ClientConnector {
stats: ClientConnectorStats::default(),
subscriber: None,
pool: Rc::new(Pool::new(Rc::clone(&_modified))),
pool_modified: _modified,
conn_lifetime: Duration::from_secs(15),
conn_keep_alive: Duration::from_secs(75),
limit: 100,
limit_per_host: 0,
acquired: 0,
acquired_per_host: HashMap::new(),
available: HashMap::new(),
to_close: Vec::new(),
waiters: HashMap::new(),
wait_timeout: None,
paused: None,
}
}
}
impl ClientConnector {
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
/// Create `ClientConnector` actor with custom `SslConnector` instance.
///
/// By default `ClientConnector` uses very a simple SSL configuration.
@ -369,20 +371,19 @@ impl ClientConnector {
// check limits
if self.limit > 0 {
if self.acquired >= self.limit {
return Acquire::NotAvailable
return Acquire::NotAvailable;
}
if self.limit_per_host > 0 {
if let Some(per_host) = self.acquired_per_host.get(key) {
if self.limit_per_host >= *per_host {
return Acquire::NotAvailable
return Acquire::NotAvailable;
}
}
}
}
else if self.limit_per_host > 0 {
} else if self.limit_per_host > 0 {
if let Some(per_host) = self.acquired_per_host.get(key) {
if self.limit_per_host >= *per_host {
return Acquire::NotAvailable
return Acquire::NotAvailable;
}
}
}
@ -408,11 +409,11 @@ impl ClientConnector {
Ok(n) if n > 0 => {
self.stats.closed += 1;
self.to_close.push(conn);
continue
},
continue;
}
Ok(_) | Err(_) => continue,
}
return Acquire::Acquired(conn)
return Acquire::Acquired(conn);
}
}
}
@ -421,25 +422,25 @@ impl ClientConnector {
fn reserve(&mut self, key: &Key) {
self.acquired += 1;
let per_host =
if let Some(per_host) = self.acquired_per_host.get(key) {
*per_host
} else {
0
};
self.acquired_per_host.insert(key.clone(), per_host + 1);
let per_host = if let Some(per_host) = self.acquired_per_host.get(key) {
*per_host
} else {
0
};
self.acquired_per_host
.insert(key.clone(), per_host + 1);
}
fn release_key(&mut self, key: &Key) {
self.acquired -= 1;
let per_host =
if let Some(per_host) = self.acquired_per_host.get(key) {
*per_host
} else {
return
};
let per_host = if let Some(per_host) = self.acquired_per_host.get(key) {
*per_host
} else {
return;
};
if per_host > 1 {
self.acquired_per_host.insert(key.clone(), per_host - 1);
self.acquired_per_host
.insert(key.clone(), per_host - 1);
} else {
self.acquired_per_host.remove(key);
}
@ -472,7 +473,8 @@ impl ClientConnector {
// check connection lifetime and the return to available pool
if (now - conn.ts) < self.conn_lifetime {
self.available.entry(conn.key.clone())
self.available
.entry(conn.key.clone())
.or_insert_with(VecDeque::new)
.push_back(Conn(Instant::now(), conn));
}
@ -490,7 +492,7 @@ impl ClientConnector {
self.to_close.push(conn);
self.stats.closed += 1;
} else {
break
break;
}
}
}
@ -503,7 +505,7 @@ impl ClientConnector {
Ok(Async::NotReady) => idx += 1,
_ => {
self.to_close.swap_remove(idx);
},
}
}
}
}
@ -514,7 +516,9 @@ impl ClientConnector {
fn collect_periodic(&mut self, ctx: &mut Context<Self>) {
self.collect(true);
// re-schedule next collect period
ctx.run_later(Duration::from_secs(1), |act, ctx| act.collect_periodic(ctx));
ctx.run_later(Duration::from_secs(1), |act, ctx| {
act.collect_periodic(ctx)
});
// send stats
let stats = mem::replace(&mut self.stats, ClientConnectorStats::default());
@ -555,27 +559,34 @@ impl ClientConnector {
fn install_wait_timeout(&mut self, time: Instant) {
if let Some(ref mut wait) = self.wait_timeout {
if wait.0 < time {
return
return;
}
}
let mut timeout = Timeout::new(time-Instant::now(), Arbiter::handle()).unwrap();
let mut timeout =
Timeout::new(time - Instant::now(), Arbiter::handle()).unwrap();
let _ = timeout.poll();
self.wait_timeout = Some((time, timeout));
}
fn wait_for(&mut self, key: Key,
wait: Duration, conn_timeout: Duration)
-> oneshot::Receiver<Result<Connection, ClientConnectorError>>
{
fn wait_for(
&mut self, key: Key, wait: Duration, conn_timeout: Duration
) -> oneshot::Receiver<Result<Connection, ClientConnectorError>> {
// connection is not available, wait
let (tx, rx) = oneshot::channel();
let wait = Instant::now() + wait;
self.install_wait_timeout(wait);
let waiter = Waiter{ tx, wait, conn_timeout };
self.waiters.entry(key).or_insert_with(VecDeque::new).push_back(waiter);
let waiter = Waiter {
tx,
wait,
conn_timeout,
};
self.waiters
.entry(key)
.or_insert_with(VecDeque::new)
.push_back(waiter);
rx
}
}
@ -617,21 +628,23 @@ impl Handler<Connect> for ClientConnector {
// host name is required
if uri.host().is_none() {
return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl))
return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl));
}
// supported protocols
let proto = match uri.scheme_part() {
Some(scheme) => match Protocol::from(scheme.as_str()) {
Some(proto) => proto,
None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)),
None => {
return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl))
}
},
None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)),
};
// check ssl availability
if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS {
return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported))
return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported));
}
// check if pool has task reference
@ -641,7 +654,11 @@ impl Handler<Connect> for ClientConnector {
let host = uri.host().unwrap().to_owned();
let port = uri.port().unwrap_or_else(|| proto.port());
let key = Key {host, port, ssl: proto.is_secure()};
let key = Key {
host,
port,
ssl: proto.is_secure(),
};
// check pause state
if self.paused.is_some() {
@ -653,7 +670,8 @@ impl Handler<Connect> for ClientConnector {
.and_then(|res, _, _| match res {
Ok(conn) => fut::ok(conn),
Err(err) => fut::err(err),
}));
}),
);
}
// acquire connection
@ -663,8 +681,8 @@ impl Handler<Connect> for ClientConnector {
// use existing connection
conn.pool = Some(AcquiredConn(key, Some(Rc::clone(&self.pool))));
self.stats.reused += 1;
return ActorResponse::async(fut::ok(conn))
},
return ActorResponse::async(fut::ok(conn));
}
Acquire::NotAvailable => {
// connection is not available, wait
let rx = self.wait_for(key, wait_timeout, conn_timeout);
@ -675,106 +693,131 @@ impl Handler<Connect> for ClientConnector {
.and_then(|res, _, _| match res {
Ok(conn) => fut::ok(conn),
Err(err) => fut::err(err),
}));
}),
);
}
Acquire::Available => {
Some(Rc::clone(&self.pool))
},
Acquire::Available => Some(Rc::clone(&self.pool)),
}
} else {
None
};
let conn = AcquiredConn(key, pool);
{
ActorResponse::async(
Connector::from_registry()
.send(ResolveConnect::host_and_port(&conn.0.host, port)
.timeout(conn_timeout))
.into_actor(self)
.map_err(|_, _, _| ClientConnectorError::Disconnected)
.and_then(move |res, act, _| {
#[cfg(feature="alpn")]
match res {
Err(err) => {
act.stats.opened += 1;
fut::Either::B(fut::err(err.into()))
},
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::Either::A(
act.connector.connect_async(&conn.0.host, stream)
.map_err(ClientConnectorError::SslError)
.map(|stream| Connection::new(
conn.0.clone(), Some(conn), Box::new(stream)))
.into_actor(act))
} else {
fut::Either::B(fut::ok(
Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))))
{
ActorResponse::async(
Connector::from_registry()
.send(
ResolveConnect::host_and_port(&conn.0.host, port)
.timeout(conn_timeout),
)
.into_actor(self)
.map_err(|_, _, _| ClientConnectorError::Disconnected)
.and_then(move |res, act, _| {
#[cfg(feature = "alpn")]
match res {
Err(err) => {
act.stats.opened += 1;
fut::Either::B(fut::err(err.into()))
}
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::Either::A(
act.connector
.connect_async(&conn.0.host, stream)
.map_err(ClientConnectorError::SslError)
.map(|stream| {
Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)
})
.into_actor(act),
)
} else {
fut::Either::B(fut::ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)))
}
}
}
}
}
#[cfg(all(feature="tls", not(feature="alpn")))]
match res {
Err(err) => {
act.stats.opened += 1;
fut::Either::B(fut::err(err.into()))
},
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::Either::A(
act.connector.connect_async(&conn.0.host, stream)
.map_err(ClientConnectorError::SslError)
.map(|stream| Connection::new(
conn.0.clone(), Some(conn), Box::new(stream)))
.into_actor(act))
} else {
fut::Either::B(fut::ok(
Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))))
#[cfg(all(feature = "tls", not(feature = "alpn")))]
match res {
Err(err) => {
act.stats.opened += 1;
fut::Either::B(fut::err(err.into()))
}
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::Either::A(
act.connector
.connect_async(&conn.0.host, stream)
.map_err(ClientConnectorError::SslError)
.map(|stream| {
Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)
})
.into_actor(act),
)
} else {
fut::Either::B(fut::ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)))
}
}
}
}
}
#[cfg(not(any(feature="alpn", feature="tls")))]
match res {
Err(err) => {
act.stats.opened += 1;
fut::err(err.into())
},
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::err(ClientConnectorError::SslIsNotSupported)
} else {
fut::ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream)))
#[cfg(not(any(feature = "alpn", feature = "tls")))]
match res {
Err(err) => {
act.stats.opened += 1;
fut::err(err.into())
}
Ok(stream) => {
act.stats.opened += 1;
if proto.is_secure() {
fut::err(ClientConnectorError::SslIsNotSupported)
} else {
fut::ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
))
}
}
}
}
}
}))
}
}),
)
}
}
}
struct Maintenance;
impl fut::ActorFuture for Maintenance
{
impl fut::ActorFuture for Maintenance {
type Item = ();
type Error = ();
type Actor = ClientConnector;
fn poll(&mut self, act: &mut ClientConnector, ctx: &mut Context<ClientConnector>)
-> Poll<Self::Item, Self::Error>
{
fn poll(
&mut self, act: &mut ClientConnector, ctx: &mut Context<ClientConnector>
) -> Poll<Self::Item, Self::Error> {
// check pause duration
let done = if let Some(Some(ref pause)) = act.paused {
pause.0 <= Instant::now() } else { false };
pause.0 <= Instant::now()
} else {
false
};
if done {
act.paused.take();
}
@ -788,128 +831,151 @@ impl fut::ActorFuture for Maintenance
act.collect_waiters();
// check waiters
let tmp: &mut ClientConnector = unsafe{mem::transmute(act as &mut _)};
let tmp: &mut ClientConnector = unsafe { mem::transmute(act as &mut _) };
for (key, waiters) in &mut tmp.waiters {
while let Some(waiter) = waiters.pop_front() {
if waiter.tx.is_canceled() { continue }
if waiter.tx.is_canceled() {
continue;
}
match act.acquire(key) {
Acquire::Acquired(mut conn) => {
// use existing connection
act.stats.reused += 1;
conn.pool = Some(
AcquiredConn(key.clone(), Some(Rc::clone(&act.pool))));
conn.pool =
Some(AcquiredConn(key.clone(), Some(Rc::clone(&act.pool))));
let _ = waiter.tx.send(Ok(conn));
},
}
Acquire::NotAvailable => {
waiters.push_front(waiter);
break
break;
}
Acquire::Available =>
{
let conn = AcquiredConn(key.clone(), Some(Rc::clone(&act.pool)));
Acquire::Available => {
let conn = AcquiredConn(key.clone(), Some(Rc::clone(&act.pool)));
fut::WrapFuture::<ClientConnector>::actfuture(
Connector::from_registry()
.send(ResolveConnect::host_and_port(&conn.0.host, conn.0.port)
.timeout(waiter.conn_timeout)))
.map_err(|_, _, _| ())
.and_then(move |res, act, _| {
#[cfg(feature="alpn")]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::Either::B(fut::err(()))
},
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
fut::Either::A(
act.connector.connect_async(&key.host, stream)
.then(move |res| {
match res {
Err(e) => {
let _ = waiter.tx.send(Err(
ClientConnectorError::SslError(e)));
},
Ok(stream) => {
let _ = waiter.tx.send(Ok(
Connection::new(
conn.0.clone(),
Some(conn), Box::new(stream))));
}
}
Ok(())
})
.actfuture())
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))));
fut::Either::B(fut::ok(()))
}
}
}
fut::WrapFuture::<ClientConnector>::actfuture(
Connector::from_registry().send(
ResolveConnect::host_and_port(&conn.0.host, conn.0.port)
.timeout(waiter.conn_timeout),
),
).map_err(|_, _, _| ())
.and_then(move |res, act, _| {
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(feature = "alpn")]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::Either::B(fut::err(()))
}
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
fut::Either::A(
act.connector
.connect_async(&key.host, stream)
.then(move |res| {
match res {
Err(e) => {
let _ = waiter.tx.send(
Err(ClientConnectorError::SslError(e)));
}
Ok(stream) => {
let _ = waiter.tx.send(
Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)),
);
}
}
Ok(())
})
.actfuture(),
)
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)));
fut::Either::B(fut::ok(()))
}
}
}
#[cfg(all(feature="tls", not(feature="alpn")))]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::Either::B(fut::err(()))
},
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
fut::Either::A(
act.connector.connect_async(&conn.0.host, stream)
.then(|res| {
match res {
Err(e) => {
let _ = waiter.tx.send(Err(
ClientConnectorError::SslError(e)));
},
Ok(stream) => {
let _ = waiter.tx.send(Ok(
Connection::new(
conn.0.clone(), Some(conn),
Box::new(stream))));
}
}
Ok(())
})
.into_actor(act))
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))));
fut::Either::B(fut::ok(()))
}
}
}
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(all(feature = "tls", not(feature = "alpn")))]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::Either::B(fut::err(()))
}
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
fut::Either::A(
act.connector
.connect_async(&conn.0.host, stream)
.then(|res| {
match res {
Err(e) => {
let _ = waiter.tx.send(Err(
ClientConnectorError::SslError(e),
));
}
Ok(stream) => {
let _ = waiter.tx.send(
Ok(Connection::new(
conn.0.clone(), Some(conn),
Box::new(stream),
)),
);
}
}
Ok(())
})
.into_actor(act),
)
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)));
fut::Either::B(fut::ok(()))
}
}
}
#[cfg(not(any(feature="alpn", feature="tls")))]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::err(())
},
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
let _ = waiter.tx.send(
Err(ClientConnectorError::SslIsNotSupported));
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))));
};
fut::ok(())
},
}
})
.spawn(ctx);
}
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(not(any(feature = "alpn", feature = "tls")))]
match res {
Err(err) => {
act.stats.errors += 1;
let _ = waiter.tx.send(Err(err.into()));
fut::err(())
}
Ok(stream) => {
act.stats.opened += 1;
if conn.0.ssl {
let _ = waiter.tx.send(Err(ClientConnectorError::SslIsNotSupported));
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)));
};
fut::ok(())
}
}
})
.spawn(ctx);
}
}
}
}
@ -954,7 +1020,7 @@ impl Protocol {
fn port(&self) -> u16 {
match *self {
Protocol::Http | Protocol::Ws => 80,
Protocol::Https | Protocol::Wss => 443
Protocol::Https | Protocol::Wss => 443,
}
}
}
@ -968,7 +1034,11 @@ struct Key {
impl Key {
fn empty() -> Key {
Key{host: String::new(), port: 0, ssl: false}
Key {
host: String::new(),
port: 0,
ssl: false,
}
}
}
@ -1035,7 +1105,10 @@ impl Pool {
if self.to_close.borrow().is_empty() {
None
} else {
Some(mem::replace(&mut *self.to_close.borrow_mut(), Vec::new()))
Some(mem::replace(
&mut *self.to_close.borrow_mut(),
Vec::new(),
))
}
}
@ -1043,7 +1116,10 @@ impl Pool {
if self.to_release.borrow().is_empty() {
None
} else {
Some(mem::replace(&mut *self.to_release.borrow_mut(), Vec::new()))
Some(mem::replace(
&mut *self.to_release.borrow_mut(),
Vec::new(),
))
}
}
@ -1072,7 +1148,6 @@ impl Pool {
}
}
pub struct Connection {
key: Key,
stream: Box<IoStream>,
@ -1088,7 +1163,12 @@ impl fmt::Debug for Connection {
impl Connection {
fn new(key: Key, pool: Option<AcquiredConn>, stream: Box<IoStream>) -> Self {
Connection {key, stream, pool, ts: Instant::now()}
Connection {
key,
stream,
pool,
ts: Instant::now(),
}
}
pub fn stream(&mut self) -> &mut IoStream {

View File

@ -28,33 +28,30 @@
//! ```
mod connector;
mod parser;
mod pipeline;
mod request;
mod response;
mod pipeline;
mod writer;
pub use self::connector::{ClientConnector, ClientConnectorError, ClientConnectorStats,
Connect, Connection, Pause, Resume};
pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError};
pub use self::pipeline::{SendRequest, SendRequestError};
pub use self::request::{ClientRequest, ClientRequestBuilder};
pub use self::response::ClientResponse;
pub use self::connector::{
Connect, Pause, Resume,
Connection, ClientConnector, ClientConnectorError, ClientConnectorStats};
pub(crate) use self::writer::HttpClientWriter;
pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError};
use error::ResponseError;
use http::Method;
use httpresponse::HttpResponse;
/// Convert `SendRequestError` to a `HttpResponse`
impl ResponseError for SendRequestError {
fn error_response(&self) -> HttpResponse {
match *self {
SendRequestError::Connector(_) => HttpResponse::BadGateway(),
_ => HttpResponse::InternalServerError(),
}
.into()
}.into()
}
}

View File

@ -1,14 +1,14 @@
use std::mem;
use httparse;
use http::{Version, HttpTryFrom, HeaderMap, StatusCode};
use http::header::{self, HeaderName, HeaderValue};
use bytes::{Bytes, BytesMut};
use futures::{Poll, Async};
use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue};
use http::{HeaderMap, HttpTryFrom, StatusCode, Version};
use httparse;
use std::mem;
use error::{ParseError, PayloadError};
use server::h1::{chunked, Decoder};
use server::{utils, IoStream};
use server::h1::{Decoder, chunked};
use super::ClientResponse;
use super::response::ClientMessage;
@ -24,28 +24,26 @@ pub struct HttpResponseParser {
#[derive(Debug, Fail)]
pub enum HttpResponseParserError {
/// Server disconnected
#[fail(display="Server disconnected")]
#[fail(display = "Server disconnected")]
Disconnect,
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Error(#[cause] ParseError),
}
impl HttpResponseParser {
pub fn parse<T>(&mut self, io: &mut T, buf: &mut BytesMut)
-> Poll<ClientResponse, HttpResponseParserError>
where T: IoStream
pub fn parse<T>(
&mut self, io: &mut T, buf: &mut BytesMut
) -> Poll<ClientResponse, HttpResponseParserError>
where
T: IoStream,
{
// if buf is empty parse_message will always return NotReady, let's avoid that
if buf.is_empty() {
match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) =>
return Err(HttpResponseParserError::Disconnect),
Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect),
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Err(err) =>
return Err(HttpResponseParserError::Error(err.into()))
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(HttpResponseParserError::Error(err.into())),
}
}
@ -56,27 +54,31 @@ impl HttpResponseParser {
Async::Ready((msg, decoder)) => {
self.decoder = decoder;
return Ok(Async::Ready(msg));
},
}
Async::NotReady => {
if buf.capacity() >= MAX_BUFFER_SIZE {
return Err(HttpResponseParserError::Error(ParseError::TooLarge));
}
match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) =>
return Err(HttpResponseParserError::Disconnect),
Ok(Async::Ready(0)) => {
return Err(HttpResponseParserError::Disconnect)
}
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) =>
return Err(HttpResponseParserError::Error(err.into())),
Err(err) => {
return Err(HttpResponseParserError::Error(err.into()))
}
}
},
}
}
}
}
pub fn parse_payload<T>(&mut self, io: &mut T, buf: &mut BytesMut)
-> Poll<Option<Bytes>, PayloadError>
where T: IoStream
pub fn parse_payload<T>(
&mut self, io: &mut T, buf: &mut BytesMut
) -> Poll<Option<Bytes>, PayloadError>
where
T: IoStream,
{
if self.decoder.is_some() {
loop {
@ -89,18 +91,17 @@ impl HttpResponseParser {
};
match self.decoder.as_mut().unwrap().decode(buf) {
Ok(Async::Ready(Some(b))) =>
return Ok(Async::Ready(Some(b))),
Ok(Async::Ready(Some(b))) => return Ok(Async::Ready(Some(b))),
Ok(Async::Ready(None)) => {
self.decoder.take();
return Ok(Async::Ready(None))
return Ok(Async::Ready(None));
}
Ok(Async::NotReady) => {
if not_ready {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
if stream_finished {
return Err(PayloadError::Incomplete)
return Err(PayloadError::Incomplete);
}
}
Err(err) => return Err(err.into()),
@ -111,16 +112,19 @@ impl HttpResponseParser {
}
}
fn parse_message(buf: &mut BytesMut)
-> Poll<(ClientResponse, Option<Decoder>), ParseError>
{
fn parse_message(
buf: &mut BytesMut
) -> Poll<(ClientResponse, Option<Decoder>), ParseError> {
// Parse http message
let bytes_ptr = buf.as_ref().as_ptr() as usize;
let mut headers: [httparse::Header; MAX_HEADERS] =
unsafe{mem::uninitialized()};
unsafe { mem::uninitialized() };
let (len, version, status, headers_len) = {
let b = unsafe{ let b: &[u8] = buf; mem::transmute(b) };
let b = unsafe {
let b: &[u8] = buf;
mem::transmute(b)
};
let mut resp = httparse::Response::new(&mut headers);
match resp.parse(b)? {
httparse::Status::Complete(len) => {
@ -147,10 +151,11 @@ impl HttpResponseParser {
let v_start = header.value.as_ptr() as usize - bytes_ptr;
let v_end = v_start + header.value.len();
let value = unsafe {
HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end)) };
HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end))
};
hdrs.append(name, value);
} else {
return Err(ParseError::Header)
return Err(ParseError::Header);
}
}
@ -163,11 +168,11 @@ impl HttpResponseParser {
Some(Decoder::length(len))
} else {
debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header)
return Err(ParseError::Header);
}
} else {
debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header)
return Err(ParseError::Header);
}
} else if chunked(&hdrs)? {
// Chunked encoding
@ -177,15 +182,25 @@ impl HttpResponseParser {
};
if let Some(decoder) = decoder {
Ok(Async::Ready(
(ClientResponse::new(
ClientMessage{status, version,
headers: hdrs, cookies: None}), Some(decoder))))
Ok(Async::Ready((
ClientResponse::new(ClientMessage {
status,
version,
headers: hdrs,
cookies: None,
}),
Some(decoder),
)))
} else {
Ok(Async::Ready(
(ClientResponse::new(
ClientMessage{status, version,
headers: hdrs, cookies: None}), None)))
Ok(Async::Ready((
ClientResponse::new(ClientMessage {
status,
version,
headers: hdrs,
cookies: None,
}),
None,
)))
}
}
}

View File

@ -1,26 +1,26 @@
use std::{io, mem};
use std::time::Duration;
use bytes::{Bytes, BytesMut};
use http::header::CONTENT_ENCODING;
use futures::{Async, Future, Poll};
use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use http::header::CONTENT_ENCODING;
use std::time::Duration;
use std::{io, mem};
use tokio_core::reactor::Timeout;
use actix::prelude::*;
use error::Error;
use super::HttpClientWriter;
use super::{ClientConnector, ClientConnectorError, Connect, Connection};
use super::{ClientRequest, ClientResponse};
use super::{HttpResponseParser, HttpResponseParserError};
use body::{Body, BodyStream};
use context::{Frame, ActorHttpContext};
use context::{ActorHttpContext, Frame};
use error::Error;
use error::PayloadError;
use header::ContentEncoding;
use httpmessage::HttpMessage;
use error::PayloadError;
use server::WriterState;
use server::shared::SharedBytes;
use server::encoding::PayloadStream;
use super::{ClientRequest, ClientResponse};
use super::{Connect, Connection, ClientConnector, ClientConnectorError};
use super::HttpClientWriter;
use super::{HttpResponseParser, HttpResponseParserError};
use server::shared::SharedBytes;
/// A set of errors that can occur during request sending and response reading
#[derive(Fail, Debug)]
@ -29,13 +29,13 @@ pub enum SendRequestError {
#[fail(display = "Timeout while waiting for response")]
Timeout,
/// Failed to connect to host
#[fail(display="Failed to connect to host: {}", _0)]
#[fail(display = "Failed to connect to host: {}", _0)]
Connector(#[cause] ClientConnectorError),
/// Error parsing response
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
ParseError(#[cause] HttpResponseParserError),
/// Error reading response payload
#[fail(display="Error reading response payload: {}", _0)]
#[fail(display = "Error reading response payload: {}", _0)]
Io(#[cause] io::Error),
}
@ -79,25 +79,27 @@ impl SendRequest {
SendRequest::with_connector(req, ClientConnector::from_registry())
}
pub(crate) fn with_connector(req: ClientRequest, conn: Addr<Unsync, ClientConnector>)
-> SendRequest
{
SendRequest{req, conn,
state: State::New,
timeout: None,
wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1),
pub(crate) fn with_connector(
req: ClientRequest, conn: Addr<Unsync, ClientConnector>
) -> SendRequest {
SendRequest {
req,
conn,
state: State::New,
timeout: None,
wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1),
}
}
pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest
{
SendRequest{req,
state: State::Connection(conn),
conn: ClientConnector::from_registry(),
timeout: None,
wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1),
pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest {
SendRequest {
req,
state: State::Connection(conn),
conn: ClientConnector::from_registry(),
timeout: None,
wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1),
}
}
@ -139,25 +141,27 @@ impl Future for SendRequest {
let state = mem::replace(&mut self.state, State::None);
match state {
State::New =>
State::New => {
self.state = State::Connect(self.conn.send(Connect {
uri: self.req.uri().clone(),
wait_timeout: self.wait_timeout,
conn_timeout: self.conn_timeout,
})),
}))
}
State::Connect(mut conn) => match conn.poll() {
Ok(Async::NotReady) => {
self.state = State::Connect(conn);
return Ok(Async::NotReady);
},
}
Ok(Async::Ready(result)) => match result {
Ok(stream) => {
self.state = State::Connection(stream)
},
Ok(stream) => self.state = State::Connection(stream),
Err(err) => return Err(err.into()),
},
Err(_) => return Err(SendRequestError::Connector(
ClientConnectorError::Disconnected))
Err(_) => {
return Err(SendRequestError::Connector(
ClientConnectorError::Disconnected,
))
}
},
State::Connection(conn) => {
let mut writer = HttpClientWriter::new(SharedBytes::default());
@ -169,12 +173,13 @@ impl Future for SendRequest {
_ => IoBody::Done,
};
let timeout = self.timeout.take().unwrap_or_else(||
Timeout::new(
Duration::from_secs(5), Arbiter::handle()).unwrap());
let timeout = self.timeout.take().unwrap_or_else(|| {
Timeout::new(Duration::from_secs(5), Arbiter::handle()).unwrap()
});
let pl = Box::new(Pipeline {
body, writer,
body,
writer,
conn: Some(conn),
parser: Some(HttpResponseParser::default()),
parser_buf: BytesMut::new(),
@ -186,22 +191,22 @@ impl Future for SendRequest {
timeout: Some(timeout),
});
self.state = State::Send(pl);
},
}
State::Send(mut pl) => {
pl.poll_write()
.map_err(|e| io::Error::new(
io::ErrorKind::Other, format!("{}", e).as_str()))?;
pl.poll_write().map_err(|e| {
io::Error::new(io::ErrorKind::Other, format!("{}", e).as_str())
})?;
match pl.parse() {
Ok(Async::Ready(mut resp)) => {
resp.set_pipeline(pl);
return Ok(Async::Ready(resp))
},
return Ok(Async::Ready(resp));
}
Ok(Async::NotReady) => {
self.state = State::Send(pl);
return Ok(Async::NotReady)
},
Err(err) => return Err(SendRequestError::ParseError(err))
return Ok(Async::NotReady);
}
Err(err) => return Err(SendRequestError::ParseError(err)),
}
}
State::None => unreachable!(),
@ -210,7 +215,6 @@ impl Future for SendRequest {
}
}
pub(crate) struct Pipeline {
body: IoBody,
conn: Option<Connection>,
@ -254,7 +258,6 @@ impl RunningState {
}
impl Pipeline {
fn release_conn(&mut self) {
if let Some(conn) = self.conn.take() {
conn.release()
@ -264,15 +267,22 @@ impl Pipeline {
#[inline]
fn parse(&mut self) -> Poll<ClientResponse, HttpResponseParserError> {
if let Some(ref mut conn) = self.conn {
match self.parser.as_mut().unwrap().parse(conn, &mut self.parser_buf) {
match self.parser
.as_mut()
.unwrap()
.parse(conn, &mut self.parser_buf)
{
Ok(Async::Ready(resp)) => {
// check content-encoding
if self.should_decompress {
if let Some(enc) = resp.headers().get(CONTENT_ENCODING) {
if let Ok(enc) = enc.to_str() {
match ContentEncoding::from(enc) {
ContentEncoding::Auto | ContentEncoding::Identity => (),
enc => self.decompress = Some(PayloadStream::new(enc)),
ContentEncoding::Auto
| ContentEncoding::Identity => (),
enc => {
self.decompress = Some(PayloadStream::new(enc))
}
}
}
}
@ -290,9 +300,10 @@ impl Pipeline {
#[inline]
pub fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> {
if self.conn.is_none() {
return Ok(Async::Ready(None))
return Ok(Async::Ready(None));
}
let conn: &mut Connection = unsafe{ mem::transmute(self.conn.as_mut().unwrap())};
let conn: &mut Connection =
unsafe { mem::transmute(self.conn.as_mut().unwrap()) };
let mut need_run = false;
@ -302,15 +313,18 @@ impl Pipeline {
{
Async::NotReady => need_run = true,
Async::Ready(_) => {
let _ = self.poll_timeout()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
let _ = self.poll_timeout().map_err(|e| {
io::Error::new(io::ErrorKind::Other, format!("{}", e))
})?;
}
}
// need read?
if self.parser.is_some() {
loop {
match self.parser.as_mut().unwrap()
match self.parser
.as_mut()
.unwrap()
.parse_payload(conn, &mut self.parser_buf)?
{
Async::Ready(Some(b)) => {
@ -318,17 +332,20 @@ impl Pipeline {
match decompress.feed_data(b) {
Ok(Some(b)) => return Ok(Async::Ready(Some(b))),
Ok(None) => return Ok(Async::NotReady),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock =>
continue,
Err(ref err)
if err.kind() == io::ErrorKind::WouldBlock =>
{
continue
}
Err(err) => return Err(err.into()),
}
} else {
return Ok(Async::Ready(Some(b)))
return Ok(Async::Ready(Some(b)));
}
},
}
Async::Ready(None) => {
let _ = self.parser.take();
break
break;
}
Async::NotReady => return Ok(Async::NotReady),
}
@ -340,7 +357,7 @@ impl Pipeline {
let res = decompress.feed_eof();
if let Some(b) = res? {
self.release_conn();
return Ok(Async::Ready(Some(b)))
return Ok(Async::Ready(Some(b)));
}
}
@ -357,7 +374,7 @@ impl Pipeline {
match self.timeout.as_mut().unwrap().poll() {
Ok(Async::Ready(())) => Err(SendRequestError::Timeout),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => unreachable!()
Err(_) => unreachable!(),
}
} else {
Ok(Async::NotReady)
@ -367,29 +384,27 @@ impl Pipeline {
#[inline]
fn poll_write(&mut self) -> Poll<(), Error> {
if self.write_state == RunningState::Done || self.conn.is_none() {
return Ok(Async::Ready(()))
return Ok(Async::Ready(()));
}
let mut done = false;
if self.drain.is_none() && self.write_state != RunningState::Paused {
'outter: loop {
let result = match mem::replace(&mut self.body, IoBody::Done) {
IoBody::Payload(mut body) => {
match body.poll()? {
Async::Ready(None) => {
self.writer.write_eof()?;
self.disconnected = true;
break
},
Async::Ready(Some(chunk)) => {
self.body = IoBody::Payload(body);
self.writer.write(chunk.into())?
}
Async::NotReady => {
done = true;
self.body = IoBody::Payload(body);
break
},
IoBody::Payload(mut body) => match body.poll()? {
Async::Ready(None) => {
self.writer.write_eof()?;
self.disconnected = true;
break;
}
Async::Ready(Some(chunk)) => {
self.body = IoBody::Payload(body);
self.writer.write(chunk.into())?
}
Async::NotReady => {
done = true;
self.body = IoBody::Payload(body);
break;
}
},
IoBody::Actor(mut ctx) => {
@ -400,7 +415,7 @@ impl Pipeline {
Async::Ready(Some(vec)) => {
if vec.is_empty() {
self.body = IoBody::Actor(ctx);
break
break;
}
let mut res = None;
for frame in vec {
@ -409,52 +424,53 @@ impl Pipeline {
// info.context = Some(ctx);
self.disconnected = true;
self.writer.write_eof()?;
break 'outter
},
Frame::Chunk(Some(chunk)) =>
res = Some(self.writer.write(chunk)?),
break 'outter;
}
Frame::Chunk(Some(chunk)) => {
res = Some(self.writer.write(chunk)?)
}
Frame::Drain(fut) => self.drain = Some(fut),
}
}
self.body = IoBody::Actor(ctx);
if self.drain.is_some() {
self.write_state.resume();
break
break;
}
res.unwrap()
},
}
Async::Ready(None) => {
done = true;
break
break;
}
Async::NotReady => {
done = true;
self.body = IoBody::Actor(ctx);
break
break;
}
}
},
}
IoBody::Done => {
self.disconnected = true;
done = true;
break
break;
}
};
match result {
WriterState::Pause => {
self.write_state.pause();
break
break;
}
WriterState::Done => {
self.write_state.resume()
},
WriterState::Done => self.write_state.resume(),
}
}
}
// flush io but only if we need to
match self.writer.poll_completed(self.conn.as_mut().unwrap(), false) {
match self.writer
.poll_completed(self.conn.as_mut().unwrap(), false)
{
Ok(Async::Ready(_)) => {
if self.disconnected {
self.write_state = RunningState::Done;
@ -472,7 +488,7 @@ impl Pipeline {
} else {
Ok(Async::NotReady)
}
},
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => Err(err.into()),
}

View File

@ -1,26 +1,26 @@
use std::{fmt, mem};
use std::fmt::Write as FmtWrite;
use std::io::Write;
use std::time::Duration;
use std::{fmt, mem};
use actix::{Addr, Unsync};
use bytes::{BufMut, Bytes, BytesMut};
use cookie::{Cookie, CookieJar};
use bytes::{Bytes, BytesMut, BufMut};
use futures::Stream;
use serde_json;
use percent_encoding::{percent_encode, USERINFO_ENCODE_SET};
use serde::Serialize;
use serde_json;
use url::Url;
use percent_encoding::{USERINFO_ENCODE_SET, percent_encode};
use super::connector::{ClientConnector, Connection};
use super::pipeline::SendRequest;
use body::Body;
use error::Error;
use header::{ContentEncoding, Header, IntoHeaderValue};
use http::header::{self, HeaderName, HeaderValue};
use http::{uri, Error as HttpError, HeaderMap, HttpTryFrom, Method, Uri, Version};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use http::{uri, HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use super::pipeline::SendRequest;
use super::connector::{Connection, ClientConnector};
/// An HTTP Client Request
///
@ -72,7 +72,6 @@ enum ConnectionType {
}
impl Default for ClientRequest {
fn default() -> ClientRequest {
ClientRequest {
uri: Uri::default(),
@ -92,7 +91,6 @@ impl Default for ClientRequest {
}
impl ClientRequest {
/// Create request builder for `GET` request
pub fn get<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
@ -130,14 +128,13 @@ impl ClientRequest {
}
impl ClientRequest {
/// Create client request builder
pub fn build() -> ClientRequestBuilder {
ClientRequestBuilder {
request: Some(ClientRequest::default()),
err: None,
cookies: None,
default_headers: true
default_headers: true,
}
}
@ -259,8 +256,11 @@ impl ClientRequest {
impl fmt::Debug for ClientRequest {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(f, "\nClientRequest {:?} {}:{}",
self.version, self.method, self.uri);
let res = writeln!(
f,
"\nClientRequest {:?} {}:{}",
self.version, self.method, self.uri
);
let _ = writeln!(f, " headers:");
for (key, val) in self.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val);
@ -277,7 +277,7 @@ pub struct ClientRequestBuilder {
request: Option<ClientRequest>,
err: Option<HttpError>,
cookies: Option<CookieJar>,
default_headers: bool
default_headers: bool,
}
impl ClientRequestBuilder {
@ -300,8 +300,8 @@ impl ClientRequestBuilder {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.uri = uri;
}
},
Err(e) => self.err = Some(e.into(),),
}
Err(e) => self.err = Some(e.into()),
}
self
}
@ -318,8 +318,8 @@ impl ClientRequestBuilder {
/// Set HTTP method of this request.
#[inline]
pub fn get_method(&mut self) -> &Method {
let parts = parts(&mut self.request, &self.err)
.expect("cannot reuse request builder");
let parts =
parts(&mut self.request, &self.err).expect("cannot reuse request builder");
&parts.method
}
@ -351,11 +351,12 @@ impl ClientRequestBuilder {
/// }
/// ```
#[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self
{
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
match hdr.try_into() {
Ok(value) => { parts.headers.insert(H::name(), value); }
Ok(value) => {
parts.headers.insert(H::name(), value);
}
Err(e) => self.err = Some(e.into()),
}
}
@ -382,15 +383,17 @@ impl ClientRequestBuilder {
/// }
/// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => {
match value.try_into() {
Ok(value) => { parts.headers.append(key, value); }
Err(e) => self.err = Some(e.into()),
Ok(key) => match value.try_into() {
Ok(value) => {
parts.headers.append(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
@ -400,15 +403,17 @@ impl ClientRequestBuilder {
/// Set a header.
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => {
match value.try_into() {
Ok(value) => { parts.headers.insert(key, value); }
Err(e) => self.err = Some(e.into()),
Ok(key) => match value.try_into() {
Ok(value) => {
parts.headers.insert(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
@ -448,11 +453,14 @@ impl ClientRequestBuilder {
/// Set request's content type
#[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self
where HeaderValue: HttpTryFrom<V>
where
HeaderValue: HttpTryFrom<V>,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); },
Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
Err(e) => self.err = Some(e.into()),
};
}
@ -491,7 +499,10 @@ impl ClientRequestBuilder {
jar.add(cookie.into_owned());
self.cookies = Some(jar)
} else {
self.cookies.as_mut().unwrap().add(cookie.into_owned());
self.cookies
.as_mut()
.unwrap()
.add(cookie.into_owned());
}
self
}
@ -551,7 +562,8 @@ impl ClientRequestBuilder {
/// This method calls provided closure with builder reference if
/// value is `true`.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: FnOnce(&mut ClientRequestBuilder)
where
F: FnOnce(&mut ClientRequestBuilder),
{
if value {
f(self);
@ -562,7 +574,8 @@ impl ClientRequestBuilder {
/// This method calls provided closure with builder reference if
/// value is `Some`.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where F: FnOnce(T, &mut ClientRequestBuilder)
where
F: FnOnce(T, &mut ClientRequestBuilder),
{
if let Some(val) = value {
f(val, self);
@ -575,18 +588,20 @@ impl ClientRequestBuilder {
/// `ClientRequestBuilder` can not be used after this call.
pub fn body<B: Into<Body>>(&mut self, body: B) -> Result<ClientRequest, Error> {
if let Some(e) = self.err.take() {
return Err(e.into())
return Err(e.into());
}
if self.default_headers {
// enable br only for https
let https =
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.uri.scheme_part()
.map(|s| s == &uri::Scheme::HTTPS).unwrap_or(true)
} else {
true
};
let https = if let Some(parts) = parts(&mut self.request, &self.err) {
parts
.uri
.scheme_part()
.map(|s| s == &uri::Scheme::HTTPS)
.unwrap_or(true)
} else {
true
};
if https {
self.header(header::ACCEPT_ENCODING, "br, gzip, deflate");
@ -595,7 +610,9 @@ impl ClientRequestBuilder {
}
}
let mut request = self.request.take().expect("cannot reuse request builder");
let mut request = self.request
.take()
.expect("cannot reuse request builder");
// set cookies
if let Some(ref mut jar) = self.cookies {
@ -606,7 +623,9 @@ impl ClientRequestBuilder {
let _ = write!(&mut cookie, "; {}={}", name, value);
}
request.headers.insert(
header::COOKIE, HeaderValue::from_str(&cookie.as_str()[2..]).unwrap());
header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
}
request.body = body.into();
Ok(request)
@ -634,10 +653,13 @@ impl ClientRequestBuilder {
///
/// `ClientRequestBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Result<ClientRequest, Error>
where S: Stream<Item=Bytes, Error=E> + 'static,
E: Into<Error>,
where
S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error>,
{
self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into()))))
self.body(Body::Streaming(Box::new(
stream.map_err(|e| e.into()),
)))
}
/// Set an empty body and generate `ClientRequest`
@ -653,17 +675,17 @@ impl ClientRequestBuilder {
request: self.request.take(),
err: self.err.take(),
cookies: self.cookies.take(),
default_headers: self.default_headers
default_headers: self.default_headers,
}
}
}
#[inline]
fn parts<'a>(parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>)
-> Option<&'a mut ClientRequest>
{
fn parts<'a>(
parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>
) -> Option<&'a mut ClientRequest> {
if err.is_some() {
return None
return None;
}
parts.as_mut()
}
@ -671,8 +693,11 @@ fn parts<'a>(parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>)
impl fmt::Debug for ClientRequestBuilder {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(ref parts) = self.request {
let res = writeln!(f, "\nClientRequestBuilder {:?} {}:{}",
parts.version, parts.method, parts.uri);
let res = writeln!(
f,
"\nClientRequestBuilder {:?} {}:{}",
parts.version, parts.method, parts.uri
);
let _ = writeln!(f, " headers:");
for (key, val) in parts.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val);

View File

@ -1,19 +1,18 @@
use std::{fmt, str};
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::rc::Rc;
use std::{fmt, str};
use bytes::Bytes;
use cookie::Cookie;
use futures::{Async, Poll, Stream};
use http::{HeaderMap, StatusCode, Version};
use http::header::{self, HeaderValue};
use http::{HeaderMap, StatusCode, Version};
use httpmessage::HttpMessage;
use error::{CookieParseError, PayloadError};
use httpmessage::HttpMessage;
use super::pipeline::Pipeline;
pub(crate) struct ClientMessage {
pub status: StatusCode,
pub version: Version,
@ -22,7 +21,6 @@ pub(crate) struct ClientMessage {
}
impl Default for ClientMessage {
fn default() -> ClientMessage {
ClientMessage {
status: StatusCode::OK,
@ -45,7 +43,6 @@ impl HttpMessage for ClientResponse {
}
impl ClientResponse {
pub(crate) fn new(msg: ClientMessage) -> ClientResponse {
ClientResponse(Rc::new(UnsafeCell::new(msg)), None)
}
@ -56,13 +53,13 @@ impl ClientResponse {
#[inline]
fn as_ref(&self) -> &ClientMessage {
unsafe{ &*self.0.get() }
unsafe { &*self.0.get() }
}
#[inline]
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))]
fn as_mut(&self) -> &mut ClientMessage {
unsafe{ &mut *self.0.get() }
unsafe { &mut *self.0.get() }
}
/// Get the HTTP version of this response.
@ -96,7 +93,7 @@ impl ClientResponse {
if let Ok(cookies) = self.cookies() {
for cookie in cookies {
if cookie.name() == name {
return Some(cookie)
return Some(cookie);
}
}
}
@ -107,7 +104,11 @@ impl ClientResponse {
impl fmt::Debug for ClientResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(
f, "\nClientResponse {:?} {}", self.version(), self.status());
f,
"\nClientResponse {:?} {}",
self.version(),
self.status()
);
let _ = writeln!(f, " headers:");
for (key, val) in self.headers().iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val);
@ -138,9 +139,13 @@ mod tests {
fn test_debug() {
let resp = ClientResponse::new(ClientMessage::default());
resp.as_mut().headers.insert(
header::COOKIE, HeaderValue::from_static("cookie1=value1"));
header::COOKIE,
HeaderValue::from_static("cookie1=value1"),
);
resp.as_mut().headers.insert(
header::COOKIE, HeaderValue::from_static("cookie2=value2"));
header::COOKIE,
HeaderValue::from_static("cookie2=value2"),
);
let dbg = format!("{:?}", resp);
assert!(dbg.contains("ClientResponse"));

View File

@ -1,30 +1,29 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::io::{self, Write};
use std::cell::RefCell;
use std::fmt::Write as FmtWrite;
use std::io::{self, Write};
use time::{self, Duration};
use bytes::{BytesMut, BufMut};
use futures::{Async, Poll};
use tokio_io::AsyncWrite;
use http::{Version, HttpTryFrom};
use http::header::{HeaderValue, DATE,
CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use flate2::Compression;
use flate2::write::{GzEncoder, DeflateEncoder};
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
use brotli2::write::BrotliEncoder;
use bytes::{BufMut, BytesMut};
use flate2::Compression;
use flate2::write::{DeflateEncoder, GzEncoder};
use futures::{Async, Poll};
use http::header::{HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, DATE,
TRANSFER_ENCODING};
use http::{HttpTryFrom, Version};
use time::{self, Duration};
use tokio_io::AsyncWrite;
use body::{Body, Binary};
use body::{Binary, Body};
use header::ContentEncoding;
use server::WriterState;
use server::shared::SharedBytes;
use server::encoding::{ContentEncoder, TransferEncoding};
use server::shared::SharedBytes;
use client::ClientRequest;
const AVERAGE_HEADER_SIZE: usize = 30;
bitflags! {
@ -46,7 +45,6 @@ pub(crate) struct HttpClientWriter {
}
impl HttpClientWriter {
pub fn new(buffer: SharedBytes) -> HttpClientWriter {
let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone()));
HttpClientWriter {
@ -64,24 +62,26 @@ impl HttpClientWriter {
}
// pub fn keepalive(&self) -> bool {
// self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE)
// }
// self.flags.contains(Flags::KEEPALIVE) &&
// !self.flags.contains(Flags::UPGRADE) }
fn write_to_stream<T: AsyncWrite>(&mut self, stream: &mut T) -> io::Result<WriterState> {
fn write_to_stream<T: AsyncWrite>(
&mut self, stream: &mut T
) -> io::Result<WriterState> {
while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) {
Ok(0) => {
self.disconnected();
return Ok(WriterState::Done);
},
}
Ok(n) => {
let _ = self.buffer.split_to(n);
},
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if self.buffer.len() > self.buffer_capacity {
return Ok(WriterState::Pause)
return Ok(WriterState::Pause);
} else {
return Ok(WriterState::Done)
return Ok(WriterState::Done);
}
}
Err(err) => return Err(err),
@ -92,7 +92,6 @@ impl HttpClientWriter {
}
impl HttpClientWriter {
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
// prepare task
self.flags.insert(Flags::STARTED);
@ -105,10 +104,16 @@ impl HttpClientWriter {
// render message
{
// status line
writeln!(self.buffer, "{} {} {:?}\r",
msg.method(),
msg.uri().path_and_query().map(|u| u.as_str()).unwrap_or("/"),
msg.version())?;
writeln!(
self.buffer,
"{} {} {:?}\r",
msg.method(),
msg.uri()
.path_and_query()
.map(|u| u.as_str())
.unwrap_or("/"),
msg.version()
)?;
// write headers
let mut buffer = self.buffer.get_mut();
@ -173,15 +178,17 @@ impl HttpClientWriter {
if self.encoder.is_eof() {
Ok(())
} else {
Err(io::Error::new(io::ErrorKind::Other,
"Last payload item, but eof is not reached"))
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
}
}
#[inline]
pub fn poll_completed<T: AsyncWrite>(&mut self, stream: &mut T, shutdown: bool)
-> Poll<(), io::Error>
{
pub fn poll_completed<T: AsyncWrite>(
&mut self, stream: &mut T, shutdown: bool
) -> Poll<(), io::Error> {
match self.write_to_stream(stream) {
Ok(WriterState::Done) => {
if shutdown {
@ -189,14 +196,13 @@ impl HttpClientWriter {
} else {
Ok(Async::Ready(()))
}
},
}
Ok(WriterState::Pause) => Ok(Async::NotReady),
Err(err) => Err(err)
Err(err) => Err(err),
}
}
}
fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder {
let version = req.version();
let mut body = req.replace_body(Body::Empty);
@ -206,21 +212,25 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
Body::Empty => {
req.headers_mut().remove(CONTENT_LENGTH);
TransferEncoding::length(0, buf)
},
}
Body::Binary(ref mut bytes) => {
if encoding.is_compression() {
let tmp = SharedBytes::default();
let transfer = TransferEncoding::eof(tmp.clone());
let mut enc = match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::default())),
ContentEncoding::Gzip => ContentEncoder::Gzip(
GzEncoder::new(transfer, Compression::default())),
#[cfg(feature="brotli")]
ContentEncoding::Br => ContentEncoder::Br(
BrotliEncoder::new(transfer, 5)),
DeflateEncoder::new(transfer, Compression::default()),
),
ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
transfer,
Compression::default(),
)),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
ContentEncoder::Br(BrotliEncoder::new(transfer, 5))
}
ContentEncoding::Identity => ContentEncoder::Identity(transfer),
ContentEncoding::Auto => unreachable!()
ContentEncoding::Auto => unreachable!(),
};
// TODO return error!
let _ = enc.write(bytes.clone());
@ -228,21 +238,26 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
*bytes = Binary::from(tmp.take());
req.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str()));
CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
encoding = ContentEncoding::Identity;
}
let mut b = BytesMut::new();
let _ = write!(b, "{}", bytes.len());
req.headers_mut().insert(
CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
CONTENT_LENGTH,
HeaderValue::try_from(b.freeze()).unwrap(),
);
TransferEncoding::eof(buf)
},
}
Body::Streaming(_) | Body::Actor(_) => {
if req.upgrade() {
if version == Version::HTTP_2 {
error!("Connection upgrade is forbidden for HTTP/2");
} else {
req.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade"));
req.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
if encoding != ContentEncoding::Identity {
encoding = ContentEncoding::Identity;
@ -257,24 +272,31 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
if encoding.is_compression() {
req.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str()));
CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
}
req.replace_body(body);
match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::default())),
ContentEncoding::Gzip => ContentEncoder::Gzip(
GzEncoder::new(transfer, Compression::default())),
#[cfg(feature="brotli")]
ContentEncoding::Br => ContentEncoder::Br(
BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => ContentEncoder::Identity(transfer),
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer,
Compression::default(),
)),
ContentEncoding::Gzip => {
ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::default()))
}
#[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
}
}
fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientRequest)
-> TransferEncoding {
fn streaming_encoding(
buf: SharedBytes, version: Version, req: &mut ClientRequest
) -> TransferEncoding {
if req.chunked() {
// Enable transfer encoding
req.headers_mut().remove(CONTENT_LENGTH);
@ -282,29 +304,28 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
} else {
req.headers_mut().insert(
TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf)
}
} else {
// if Content-Length is specified, then use it as length hint
let (len, chunked) =
if let Some(len) = req.headers().get(CONTENT_LENGTH) {
// Content-Length
if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() {
(Some(len), false)
} else {
error!("illegal Content-Length: {:?}", len);
(None, false)
}
let (len, chunked) = if let Some(len) = req.headers().get(CONTENT_LENGTH) {
// Content-Length
if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() {
(Some(len), false)
} else {
error!("illegal Content-Length: {:?}", len);
(None, false)
}
} else {
(None, true)
};
error!("illegal Content-Length: {:?}", len);
(None, false)
}
} else {
(None, true)
};
if !chunked {
if let Some(len) = len {
@ -316,10 +337,10 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
// Enable transfer encoding
match version {
Version::HTTP_11 => {
req.headers_mut().insert(
TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf)
},
}
_ => {
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
@ -329,7 +350,6 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
}
}
// "Sun, 06 Nov 1994 08:49:37 GMT".len()
pub const DATE_VALUE_LENGTH: usize = 29;

View File

@ -1,20 +1,19 @@
use std::mem;
use std::marker::PhantomData;
use futures::{Async, Future, Poll};
use futures::sync::oneshot::Sender;
use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use smallvec::SmallVec;
use std::marker::PhantomData;
use std::mem;
use actix::{Actor, ActorState, ActorContext, AsyncContext,
Addr, Handler, Message, SpawnHandle, Syn, Unsync};
use actix::dev::{ContextImpl, SyncEnvelope, ToEnvelope};
use actix::fut::ActorFuture;
use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope};
use actix::{Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message,
SpawnHandle, Syn, Unsync};
use body::{Body, Binary};
use body::{Binary, Body};
use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest;
pub trait ActorHttpContext: 'static {
fn disconnected(&mut self);
fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error>;
@ -36,7 +35,9 @@ impl Frame {
}
/// Execution context for http actors
pub struct HttpContext<A, S=()> where A: Actor<Context=HttpContext<A, S>>,
pub struct HttpContext<A, S = ()>
where
A: Actor<Context = HttpContext<A, S>>,
{
inner: ContextImpl<A>,
stream: Option<SmallVec<[Frame; 4]>>,
@ -44,7 +45,9 @@ pub struct HttpContext<A, S=()> where A: Actor<Context=HttpContext<A, S>>,
disconnected: bool,
}
impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self>
impl<A, S> ActorContext for HttpContext<A, S>
where
A: Actor<Context = Self>,
{
fn stop(&mut self) {
self.inner.stop();
@ -57,25 +60,29 @@ impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self>
}
}
impl<A, S> AsyncContext<A> for HttpContext<A, S> where A: Actor<Context=Self>
impl<A, S> AsyncContext<A> for HttpContext<A, S>
where
A: Actor<Context = Self>,
{
#[inline]
fn spawn<F>(&mut self, fut: F) -> SpawnHandle
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{
self.inner.spawn(fut)
}
#[inline]
fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{
self.inner.wait(fut)
}
#[doc(hidden)]
#[inline]
fn waiting(&self) -> bool {
self.inner.waiting() || self.inner.state() == ActorState::Stopping ||
self.inner.state() == ActorState::Stopped
self.inner.waiting() || self.inner.state() == ActorState::Stopping
|| self.inner.state() == ActorState::Stopped
}
#[inline]
fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
@ -93,8 +100,10 @@ impl<A, S> AsyncContext<A> for HttpContext<A, S> where A: Actor<Context=Self>
}
}
impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> {
impl<A, S: 'static> HttpContext<A, S>
where
A: Actor<Context = Self>,
{
#[inline]
pub fn new(req: HttpRequest<S>, actor: A) -> HttpContext<A, S> {
HttpContext::from_request(req).actor(actor)
@ -114,8 +123,10 @@ impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> {
}
}
impl<A, S> HttpContext<A, S> where A: Actor<Context=Self> {
impl<A, S> HttpContext<A, S>
where
A: Actor<Context = Self>,
{
/// Shared application state
#[inline]
pub fn state(&self) -> &S {
@ -175,8 +186,11 @@ impl<A, S> HttpContext<A, S> where A: Actor<Context=Self> {
}
}
impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>, S: 'static {
impl<A, S> ActorHttpContext for HttpContext<A, S>
where
A: Actor<Context = Self>,
S: 'static,
{
#[inline]
fn disconnected(&mut self) {
self.disconnected = true;
@ -184,9 +198,8 @@ impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>,
}
fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error> {
let ctx: &mut HttpContext<A, S> = unsafe {
mem::transmute(self as &mut HttpContext<A, S>)
};
let ctx: &mut HttpContext<A, S> =
unsafe { mem::transmute(self as &mut HttpContext<A, S>) };
if self.inner.alive() {
match self.inner.poll(ctx) {
@ -207,8 +220,10 @@ impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>,
}
impl<A, M, S> ToEnvelope<Syn, A, M> for HttpContext<A, S>
where A: Actor<Context=HttpContext<A, S>> + Handler<M>,
M: Message + Send + 'static, M::Result: Send,
where
A: Actor<Context = HttpContext<A, S>> + Handler<M>,
M: Message + Send + 'static,
M::Result: Send,
{
fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> {
SyncEnvelope::new(msg, tx)
@ -216,8 +231,9 @@ impl<A, M, S> ToEnvelope<Syn, A, M> for HttpContext<A, S>
}
impl<A, S> From<HttpContext<A, S>> for Body
where A: Actor<Context=HttpContext<A, S>>,
S: 'static
where
A: Actor<Context = HttpContext<A, S>>,
S: 'static,
{
fn from(ctx: HttpContext<A, S>) -> Body {
Body::Actor(Box::new(ctx))
@ -231,7 +247,10 @@ pub struct Drain<A> {
impl<A> Drain<A> {
pub fn new(fut: oneshot::Receiver<()>) -> Self {
Drain { fut, _a: PhantomData }
Drain {
fut,
_a: PhantomData,
}
}
}
@ -241,10 +260,9 @@ impl<A: Actor> ActorFuture for Drain<A> {
type Actor = A;
#[inline]
fn poll(&mut self,
_: &mut A,
_: &mut <Self::Actor as Actor>::Context) -> Poll<Self::Item, Self::Error>
{
fn poll(
&mut self, _: &mut A, _: &mut <Self::Actor as Actor>::Context
) -> Poll<Self::Item, Self::Error> {
self.fut.poll().map_err(|_| ())
}
}

228
src/de.rs
View File

@ -1,11 +1,10 @@
use std::slice::Iter;
use serde::de::{self, Deserializer, Error as DeError, Visitor};
use std::borrow::Cow;
use std::convert::AsRef;
use serde::de::{self, Deserializer, Visitor, Error as DeError};
use std::slice::Iter;
use httprequest::HttpRequest;
macro_rules! unsupported_type {
($trait_fn:ident, $name:expr) => {
fn $trait_fn<V>(self, _: V) -> Result<V::Value, Self::Error>
@ -37,103 +36,136 @@ macro_rules! parse_single_value {
}
pub struct PathDeserializer<'de, S: 'de> {
req: &'de HttpRequest<S>
req: &'de HttpRequest<S>,
}
impl<'de, S: 'de> PathDeserializer<'de, S> {
pub fn new(req: &'de HttpRequest<S>) -> Self {
PathDeserializer{req}
PathDeserializer { req }
}
}
impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S>
{
impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> {
type Error = de::value::Error;
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_map(ParamsDeserializer{
visitor.visit_map(ParamsDeserializer {
params: self.req.match_info().iter(),
current: None,
})
}
fn deserialize_struct<V>(self, _: &'static str, _: &'static [&'static str], visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>,
fn deserialize_struct<V>(
self, _: &'static str, _: &'static [&'static str], visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_map(visitor)
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(self, _: &'static str, visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_unit_struct<V>(
self, _: &'static str, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_unit(visitor)
}
fn deserialize_newtype_struct<V>(self, _: &'static str, visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>,
fn deserialize_newtype_struct<V>(
self, _: &'static str, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_tuple<V>(
self, len: usize, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.req.match_info().len() < len {
Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected {}",
self.req.match_info().len(), len).as_str()))
format!(
"wrong number of parameters: {} expected {}",
self.req.match_info().len(),
len
).as_str(),
))
} else {
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()})
visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
}
}
fn deserialize_tuple_struct<V>(self, _: &'static str, len: usize, visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_tuple_struct<V>(
self, _: &'static str, len: usize, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.req.match_info().len() < len {
Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected {}",
self.req.match_info().len(), len).as_str()))
format!(
"wrong number of parameters: {} expected {}",
self.req.match_info().len(),
len
).as_str(),
))
} else {
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()})
visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
}
}
fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], _: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_enum<V>(
self, _: &'static str, _: &'static [&'static str], _: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("unsupported type: enum"))
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
if self.req.match_info().len() != 1 {
Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected 1",
self.req.match_info().len()).as_str()))
format!(
"wrong number of parameters: {} expected 1",
self.req.match_info().len()
).as_str(),
))
} else {
visitor.visit_str(&self.req.match_info()[0])
}
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>
where
V: Visitor<'de>,
{
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()})
visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
}
unsupported_type!(deserialize_any, "'any'");
@ -163,22 +195,25 @@ struct ParamsDeserializer<'de> {
current: Option<(&'de str, &'de str)>,
}
impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de>
{
impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> {
type Error = de::value::Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where K: de::DeserializeSeed<'de>,
where
K: de::DeserializeSeed<'de>,
{
self.current = self.params.next().map(|&(ref k, ref v)| (k.as_ref(), v.as_ref()));
self.current = self.params
.next()
.map(|&(ref k, ref v)| (k.as_ref(), v.as_ref()));
match self.current {
Some((key, _)) => Ok(Some(seed.deserialize(Key{key})?)),
Some((key, _)) => Ok(Some(seed.deserialize(Key { key })?)),
None => Ok(None),
}
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where V: de::DeserializeSeed<'de>,
where
V: de::DeserializeSeed<'de>,
{
if let Some((_, value)) = self.current.take() {
seed.deserialize(Value { value })
@ -196,13 +231,15 @@ impl<'de> Deserializer<'de> for Key<'de> {
type Error = de::value::Error;
fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_str(self.key)
}
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("Unexpected"))
}
@ -231,8 +268,7 @@ struct Value<'de> {
value: &'de str,
}
impl<'de> Deserializer<'de> for Value<'de>
{
impl<'de> Deserializer<'de> for Value<'de> {
type Error = de::value::Error;
parse_value!(deserialize_bool, visit_bool, "bool");
@ -251,74 +287,94 @@ impl<'de> Deserializer<'de> for Value<'de>
parse_value!(deserialize_char, visit_char, "char");
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(
self, _: &'static str, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>
self, _: &'static str, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_borrowed_bytes(self.value.as_bytes())
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_borrowed_str(self.value)
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
visitor.visit_some(self)
}
fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>,
fn deserialize_enum<V>(
self, _: &'static str, _: &'static [&'static str], visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_enum(ValueEnum {value: self.value})
visitor.visit_enum(ValueEnum {
value: self.value,
})
}
fn deserialize_newtype_struct<V>(self, _: &'static str, visitor: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>,
fn deserialize_newtype_struct<V>(
self, _: &'static str, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_tuple<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("unsupported type: tuple"))
}
fn deserialize_struct<V>(self, _: &'static str, _: &'static [&'static str], _: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_struct<V>(
self, _: &'static str, _: &'static [&'static str], _: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("unsupported type: struct"))
}
fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>
fn deserialize_tuple_struct<V>(
self, _: &'static str, _: usize, _: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("unsupported type: tuple struct"))
Err(de::value::Error::custom(
"unsupported type: tuple struct",
))
}
unsupported_type!(deserialize_any, "any");
@ -331,15 +387,17 @@ struct ParamsSeq<'de> {
params: Iter<'de, (Cow<'de, str>, Cow<'de, str>)>,
}
impl<'de> de::SeqAccess<'de> for ParamsSeq<'de>
{
impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> {
type Error = de::value::Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where T: de::DeserializeSeed<'de>,
where
T: de::DeserializeSeed<'de>,
{
match self.params.next() {
Some(item) => Ok(Some(seed.deserialize(Value { value: item.1.as_ref() })?)),
Some(item) => Ok(Some(seed.deserialize(Value {
value: item.1.as_ref(),
})?)),
None => Ok(None),
}
}
@ -354,9 +412,13 @@ impl<'de> de::EnumAccess<'de> for ValueEnum<'de> {
type Variant = UnitVariant;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where V: de::DeserializeSeed<'de>,
where
V: de::DeserializeSeed<'de>,
{
Ok((seed.deserialize(Key { key: self.value })?, UnitVariant))
Ok((
seed.deserialize(Key { key: self.value })?,
UnitVariant,
))
}
}
@ -370,20 +432,24 @@ impl<'de> de::VariantAccess<'de> for UnitVariant {
}
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error>
where T: de::DeserializeSeed<'de>,
where
T: de::DeserializeSeed<'de>,
{
Err(de::value::Error::custom("not supported"))
}
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>,
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("not supported"))
}
fn struct_variant<V>(self, _: &'static [&'static str], _: V)
-> Result<V::Value, Self::Error>
where V: Visitor<'de>,
fn struct_variant<V>(
self, _: &'static [&'static str], _: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(de::value::Error::custom("not supported"))
}

View File

@ -1,24 +1,24 @@
//! Error and Result module
use std::{io, fmt, result};
use std::io::Error as IoError;
use std::str::Utf8Error;
use std::string::FromUtf8Error;
use std::io::Error as IoError;
use std::{fmt, io, result};
use cookie;
use httparse;
use actix::MailboxError;
use cookie;
use failure::{self, Backtrace, Fail};
use futures::Canceled;
use failure::{self, Fail, Backtrace};
use http2::Error as Http2Error;
use http::{header, StatusCode, Error as HttpError};
use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode};
use http2::Error as Http2Error;
use http_range::HttpRangeParseError;
use httparse;
use serde::de::value::Error as DeError;
use serde_json::error::Error as JsonError;
pub use url::ParseError as UrlParseError;
// re-exports
pub use cookie::{ParseError as CookieParseError};
pub use cookie::ParseError as CookieParseError;
use handler::Responder;
use httprequest::HttpRequest;
@ -27,9 +27,10 @@ use httpresponse::HttpResponse;
/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html)
/// for actix web operations
///
/// This typedef is generally used to avoid writing out `actix_web::error::Error` directly and
/// is otherwise a direct mapping to `Result`.
pub type Result<T, E=Error> = result::Result<T, E>;
/// This typedef is generally used to avoid writing out
/// `actix_web::error::Error` directly and is otherwise a direct mapping to
/// `Result`.
pub type Result<T, E = Error> = result::Result<T, E>;
/// General purpose actix web error
pub struct Error {
@ -38,7 +39,6 @@ pub struct Error {
}
impl Error {
/// Returns a reference to the underlying cause of this Error.
// this should return &Fail but needs this https://github.com/rust-lang/rust/issues/5665
pub fn cause(&self) -> &ResponseError {
@ -48,7 +48,6 @@ impl Error {
/// Error that can be converted to `HttpResponse`
pub trait ResponseError: Fail {
/// Create response for error
///
/// Internal server error is generated by default.
@ -68,7 +67,12 @@ impl fmt::Debug for Error {
if let Some(bt) = self.cause.backtrace() {
write!(f, "{:?}\n\n{:?}", &self.cause, bt)
} else {
write!(f, "{:?}\n\n{:?}", &self.cause, self.backtrace.as_ref().unwrap())
write!(
f,
"{:?}\n\n{:?}",
&self.cause,
self.backtrace.as_ref().unwrap()
)
}
}
}
@ -88,13 +92,19 @@ impl<T: ResponseError> From<T> for Error {
} else {
None
};
Error { cause: Box::new(err), backtrace }
Error {
cause: Box::new(err),
backtrace,
}
}
}
/// Compatibility for `failure::Error`
impl<T> ResponseError for failure::Compat<T>
where T: fmt::Display + fmt::Debug + Sync + Send + 'static { }
where
T: fmt::Display + fmt::Debug + Sync + Send + 'static,
{
}
impl From<failure::Error> for Error {
fn from(err: failure::Error) -> Error {
@ -128,15 +138,11 @@ impl ResponseError for HttpError {}
/// Return `InternalServerError` for `io::Error`
impl ResponseError for io::Error {
fn error_response(&self) -> HttpResponse {
match self.kind() {
io::ErrorKind::NotFound =>
HttpResponse::new(StatusCode::NOT_FOUND),
io::ErrorKind::PermissionDenied =>
HttpResponse::new(StatusCode::FORBIDDEN),
_ =>
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
io::ErrorKind::NotFound => HttpResponse::new(StatusCode::NOT_FOUND),
io::ErrorKind::PermissionDenied => HttpResponse::new(StatusCode::FORBIDDEN),
_ => HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR),
}
}
}
@ -145,7 +151,7 @@ impl ResponseError for io::Error {
impl ResponseError for header::InvalidHeaderValue {
fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST)
}
}
}
/// `BadRequest` for `InvalidHeaderValue`
@ -165,35 +171,36 @@ impl ResponseError for MailboxError {}
#[derive(Fail, Debug)]
pub enum ParseError {
/// An invalid `Method`, such as `GE.T`.
#[fail(display="Invalid Method specified")]
#[fail(display = "Invalid Method specified")]
Method,
/// An invalid `Uri`, such as `exam ple.domain`.
#[fail(display="Uri error: {}", _0)]
#[fail(display = "Uri error: {}", _0)]
Uri(InvalidUri),
/// An invalid `HttpVersion`, such as `HTP/1.1`
#[fail(display="Invalid HTTP version specified")]
#[fail(display = "Invalid HTTP version specified")]
Version,
/// An invalid `Header`.
#[fail(display="Invalid Header provided")]
#[fail(display = "Invalid Header provided")]
Header,
/// A message head is too large to be reasonable.
#[fail(display="Message head is too large")]
#[fail(display = "Message head is too large")]
TooLarge,
/// A message reached EOF, but is not complete.
#[fail(display="Message is incomplete")]
#[fail(display = "Message is incomplete")]
Incomplete,
/// An invalid `Status`, such as `1337 ELITE`.
#[fail(display="Invalid Status provided")]
#[fail(display = "Invalid Status provided")]
Status,
/// A timeout occurred waiting for an IO event.
#[allow(dead_code)]
#[fail(display="Timeout")]
#[fail(display = "Timeout")]
Timeout,
/// An `io::Error` that occurred while trying to read or write to a network stream.
#[fail(display="IO error: {}", _0)]
/// An `io::Error` that occurred while trying to read or write to a network
/// stream.
#[fail(display = "IO error: {}", _0)]
Io(#[cause] IoError),
/// Parsing a field as string failed
#[fail(display="UTF8 error: {}", _0)]
#[fail(display = "UTF8 error: {}", _0)]
Utf8(#[cause] Utf8Error),
}
@ -231,8 +238,10 @@ impl From<FromUtf8Error> for ParseError {
impl From<httparse::Error> for ParseError {
fn from(err: httparse::Error) -> ParseError {
match err {
httparse::Error::HeaderName | httparse::Error::HeaderValue |
httparse::Error::NewLine | httparse::Error::Token => ParseError::Header,
httparse::Error::HeaderName
| httparse::Error::HeaderValue
| httparse::Error::NewLine
| httparse::Error::Token => ParseError::Header,
httparse::Error::Status => ParseError::Status,
httparse::Error::TooManyHeaders => ParseError::TooLarge,
httparse::Error::Version => ParseError::Version,
@ -244,22 +253,22 @@ impl From<httparse::Error> for ParseError {
/// A set of errors that can occur during payload parsing
pub enum PayloadError {
/// A payload reached EOF, but is not complete.
#[fail(display="A payload reached EOF, but is not complete.")]
#[fail(display = "A payload reached EOF, but is not complete.")]
Incomplete,
/// Content encoding stream corruption
#[fail(display="Can not decode content-encoding.")]
#[fail(display = "Can not decode content-encoding.")]
EncodingCorrupted,
/// A payload reached size limit.
#[fail(display="A payload reached size limit.")]
#[fail(display = "A payload reached size limit.")]
Overflow,
/// A payload length is unknown.
#[fail(display="A payload length is unknown.")]
#[fail(display = "A payload length is unknown.")]
UnknownLength,
/// Io error
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Io(#[cause] IoError),
/// Http2 error
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Http2(#[cause] Http2Error),
}
@ -283,12 +292,12 @@ impl ResponseError for cookie::ParseError {
#[derive(Fail, PartialEq, Debug)]
pub enum HttpRangeError {
/// Returned if range is invalid.
#[fail(display="Range header is invalid")]
#[fail(display = "Range header is invalid")]
InvalidRange,
/// Returned if first-byte-pos of all of the byte-range-spec
/// values is greater than the content size.
/// See `https://github.com/golang/go/commit/aa9b3d7`
#[fail(display="First-byte-pos of all of the byte-range-spec values is greater than the content size")]
#[fail(display = "First-byte-pos of all of the byte-range-spec values is greater than the content size")]
NoOverlap,
}
@ -296,7 +305,9 @@ pub enum HttpRangeError {
impl ResponseError for HttpRangeError {
fn error_response(&self) -> HttpResponse {
HttpResponse::with_body(
StatusCode::BAD_REQUEST, "Invalid Range header provided")
StatusCode::BAD_REQUEST,
"Invalid Range header provided",
)
}
}
@ -313,22 +324,22 @@ impl From<HttpRangeParseError> for HttpRangeError {
#[derive(Fail, Debug)]
pub enum MultipartError {
/// Content-Type header is not found
#[fail(display="No Content-type header found")]
#[fail(display = "No Content-type header found")]
NoContentType,
/// Can not parse Content-Type header
#[fail(display="Can not parse Content-Type header")]
#[fail(display = "Can not parse Content-Type header")]
ParseContentType,
/// Multipart boundary is not found
#[fail(display="Multipart boundary is not found")]
#[fail(display = "Multipart boundary is not found")]
Boundary,
/// Multipart stream is incomplete
#[fail(display="Multipart stream is incomplete")]
#[fail(display = "Multipart stream is incomplete")]
Incomplete,
/// Error during field parsing
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Parse(#[cause] ParseError),
/// Payload error
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Payload(#[cause] PayloadError),
}
@ -346,7 +357,6 @@ impl From<PayloadError> for MultipartError {
/// Return `BadRequest` for `MultipartError`
impl ResponseError for MultipartError {
fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST)
}
@ -356,10 +366,10 @@ impl ResponseError for MultipartError {
#[derive(Fail, PartialEq, Debug)]
pub enum ExpectError {
/// Expect header value can not be converted to utf8
#[fail(display="Expect header value can not be converted to utf8")]
#[fail(display = "Expect header value can not be converted to utf8")]
Encoding,
/// Unknown expect value
#[fail(display="Unknown expect value")]
#[fail(display = "Unknown expect value")]
UnknownExpect,
}
@ -373,10 +383,10 @@ impl ResponseError for ExpectError {
#[derive(Fail, PartialEq, Debug)]
pub enum ContentTypeError {
/// Can not parse content type
#[fail(display="Can not parse content type")]
#[fail(display = "Can not parse content type")]
ParseError,
/// Unknown content encoding
#[fail(display="Unknown content encoding")]
#[fail(display = "Unknown content encoding")]
UnknownEncoding,
}
@ -391,36 +401,36 @@ impl ResponseError for ContentTypeError {
#[derive(Fail, Debug)]
pub enum UrlencodedError {
/// Can not decode chunked transfer encoding
#[fail(display="Can not decode chunked transfer encoding")]
#[fail(display = "Can not decode chunked transfer encoding")]
Chunked,
/// Payload size is bigger than 256k
#[fail(display="Payload size is bigger than 256k")]
#[fail(display = "Payload size is bigger than 256k")]
Overflow,
/// Payload size is now known
#[fail(display="Payload size is now known")]
#[fail(display = "Payload size is now known")]
UnknownLength,
/// Content type error
#[fail(display="Content type error")]
#[fail(display = "Content type error")]
ContentType,
/// Parse error
#[fail(display="Parse error")]
#[fail(display = "Parse error")]
Parse,
/// Payload error
#[fail(display="Error that occur during reading payload: {}", _0)]
#[fail(display = "Error that occur during reading payload: {}", _0)]
Payload(#[cause] PayloadError),
}
/// Return `BadRequest` for `UrlencodedError`
impl ResponseError for UrlencodedError {
fn error_response(&self) -> HttpResponse {
match *self {
UrlencodedError::Overflow =>
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE),
UrlencodedError::UnknownLength =>
HttpResponse::new(StatusCode::LENGTH_REQUIRED),
_ =>
HttpResponse::new(StatusCode::BAD_REQUEST),
UrlencodedError::Overflow => {
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE)
}
UrlencodedError::UnknownLength => {
HttpResponse::new(StatusCode::LENGTH_REQUIRED)
}
_ => HttpResponse::new(StatusCode::BAD_REQUEST),
}
}
}
@ -435,28 +445,27 @@ impl From<PayloadError> for UrlencodedError {
#[derive(Fail, Debug)]
pub enum JsonPayloadError {
/// Payload size is bigger than 256k
#[fail(display="Payload size is bigger than 256k")]
#[fail(display = "Payload size is bigger than 256k")]
Overflow,
/// Content type error
#[fail(display="Content type error")]
#[fail(display = "Content type error")]
ContentType,
/// Deserialize error
#[fail(display="Json deserialize error: {}", _0)]
#[fail(display = "Json deserialize error: {}", _0)]
Deserialize(#[cause] JsonError),
/// Payload error
#[fail(display="Error that occur during reading payload: {}", _0)]
#[fail(display = "Error that occur during reading payload: {}", _0)]
Payload(#[cause] PayloadError),
}
/// Return `BadRequest` for `UrlencodedError`
impl ResponseError for JsonPayloadError {
fn error_response(&self) -> HttpResponse {
match *self {
JsonPayloadError::Overflow =>
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE),
_ =>
HttpResponse::new(StatusCode::BAD_REQUEST),
JsonPayloadError::Overflow => {
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE)
}
_ => HttpResponse::new(StatusCode::BAD_REQUEST),
}
}
}
@ -478,19 +487,18 @@ impl From<JsonError> for JsonPayloadError {
#[derive(Fail, Debug, PartialEq)]
pub enum UriSegmentError {
/// The segment started with the wrapped invalid character.
#[fail(display="The segment started with the wrapped invalid character")]
#[fail(display = "The segment started with the wrapped invalid character")]
BadStart(char),
/// The segment contained the wrapped invalid character.
#[fail(display="The segment contained the wrapped invalid character")]
#[fail(display = "The segment contained the wrapped invalid character")]
BadChar(char),
/// The segment ended with the wrapped invalid character.
#[fail(display="The segment ended with the wrapped invalid character")]
#[fail(display = "The segment ended with the wrapped invalid character")]
BadEnd(char),
}
/// Return `BadRequest` for `UriSegmentError`
impl ResponseError for UriSegmentError {
fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST)
}
@ -499,13 +507,13 @@ impl ResponseError for UriSegmentError {
/// Errors which can occur when attempting to generate resource uri.
#[derive(Fail, Debug, PartialEq)]
pub enum UrlGenerationError {
#[fail(display="Resource not found")]
#[fail(display = "Resource not found")]
ResourceNotFound,
#[fail(display="Not all path pattern covered")]
#[fail(display = "Not all path pattern covered")]
NotEnoughElements,
#[fail(display="Router is not available")]
#[fail(display = "Router is not available")]
RouterNotAvailable,
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
ParseError(#[cause] UrlParseError),
}
@ -520,8 +528,9 @@ impl From<UrlParseError> for UrlGenerationError {
/// Helper type that can wrap any error and generate custom response.
///
/// In following example any `io::Error` will be converted into "BAD REQUEST" response
/// as opposite to *INNTERNAL SERVER ERROR* which is defined by default.
/// In following example any `io::Error` will be converted into "BAD REQUEST"
/// response as opposite to *INNTERNAL SERVER ERROR* which is defined by
/// default.
///
/// ```rust
/// # extern crate actix_web;
@ -554,7 +563,8 @@ impl<T> InternalError<T> {
}
impl<T> Fail for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
fn backtrace(&self) -> Option<&Backtrace> {
Some(&self.backtrace)
@ -562,7 +572,8 @@ impl<T> Fail for InternalError<T>
}
impl<T> fmt::Debug for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.cause, f)
@ -570,7 +581,8 @@ impl<T> fmt::Debug for InternalError<T>
}
impl<T> fmt::Display for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.cause, f)
@ -578,7 +590,8 @@ impl<T> fmt::Display for InternalError<T>
}
impl<T> ResponseError for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
fn error_response(&self) -> HttpResponse {
HttpResponse::new(self.status)
@ -586,7 +599,8 @@ impl<T> ResponseError for InternalError<T>
}
impl<T> Responder for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
type Item = HttpResponse;
type Error = Error;
@ -596,82 +610,102 @@ impl<T> Responder for InternalError<T>
}
}
/// Helper function that creates wrapper of any error and generate *BAD REQUEST* response.
/// Helper function that creates wrapper of any error and generate *BAD
/// REQUEST* response.
#[allow(non_snake_case)]
pub fn ErrorBadRequest<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::BAD_REQUEST).into()
}
/// Helper function that creates wrapper of any error and generate *UNAUTHORIZED* response.
/// Helper function that creates wrapper of any error and generate
/// *UNAUTHORIZED* response.
#[allow(non_snake_case)]
pub fn ErrorUnauthorized<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::UNAUTHORIZED).into()
}
/// Helper function that creates wrapper of any error and generate *FORBIDDEN* response.
/// Helper function that creates wrapper of any error and generate *FORBIDDEN*
/// response.
#[allow(non_snake_case)]
pub fn ErrorForbidden<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::FORBIDDEN).into()
}
/// Helper function that creates wrapper of any error and generate *NOT FOUND* response.
/// Helper function that creates wrapper of any error and generate *NOT FOUND*
/// response.
#[allow(non_snake_case)]
pub fn ErrorNotFound<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::NOT_FOUND).into()
}
/// Helper function that creates wrapper of any error and generate *METHOD NOT ALLOWED* response.
/// Helper function that creates wrapper of any error and generate *METHOD NOT
/// ALLOWED* response.
#[allow(non_snake_case)]
pub fn ErrorMethodNotAllowed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into()
}
/// Helper function that creates wrapper of any error and generate *REQUEST TIMEOUT* response.
/// Helper function that creates wrapper of any error and generate *REQUEST
/// TIMEOUT* response.
#[allow(non_snake_case)]
pub fn ErrorRequestTimeout<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into()
}
/// Helper function that creates wrapper of any error and generate *CONFLICT* response.
/// Helper function that creates wrapper of any error and generate *CONFLICT*
/// response.
#[allow(non_snake_case)]
pub fn ErrorConflict<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::CONFLICT).into()
}
/// Helper function that creates wrapper of any error and generate *GONE* response.
/// Helper function that creates wrapper of any error and generate *GONE*
/// response.
#[allow(non_snake_case)]
pub fn ErrorGone<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::GONE).into()
}
/// Helper function that creates wrapper of any error and generate *PRECONDITION FAILED* response.
/// Helper function that creates wrapper of any error and generate
/// *PRECONDITION FAILED* response.
#[allow(non_snake_case)]
pub fn ErrorPreconditionFailed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::PRECONDITION_FAILED).into()
}
/// Helper function that creates wrapper of any error and generate *EXPECTATION FAILED* response.
/// Helper function that creates wrapper of any error and generate
/// *EXPECTATION FAILED* response.
#[allow(non_snake_case)]
pub fn ErrorExpectationFailed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::EXPECTATION_FAILED).into()
}
@ -680,26 +714,28 @@ pub fn ErrorExpectationFailed<T>(err: T) -> Error
/// generate *INTERNAL SERVER ERROR* response.
#[allow(non_snake_case)]
pub fn ErrorInternalServerError<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static
where
T: Send + Sync + fmt::Debug + 'static,
{
InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into()
}
#[cfg(test)]
mod tests {
use super::*;
use cookie::ParseError as CookieParseError;
use failure;
use http::{Error as HttpError, StatusCode};
use httparse;
use std::env;
use std::error::Error as StdError;
use std::io;
use httparse;
use http::{StatusCode, Error as HttpError};
use cookie::ParseError as CookieParseError;
use failure;
use super::*;
#[test]
#[cfg(actix_nightly)]
fn test_nightly() {
let resp: HttpResponse = IoError::new(io::ErrorKind::Other, "test").error_response();
let resp: HttpResponse =
IoError::new(io::ErrorKind::Other, "test").error_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
@ -775,10 +811,10 @@ mod tests {
match ParseError::from($from) {
e @ $error => {
assert!(format!("{}", e).len() >= 5);
} ,
e => unreachable!("{:?}", e)
}
e => unreachable!("{:?}", e),
}
}
};
}
macro_rules! from_and_cause {
@ -787,10 +823,10 @@ mod tests {
e @ $error => {
let desc = format!("{}", e.cause().unwrap());
assert_eq!(desc, $from.description().to_owned());
},
_ => unreachable!("{:?}", $from)
}
_ => unreachable!("{:?}", $from),
}
}
};
}
#[test]
@ -814,7 +850,10 @@ mod tests {
env::set_var(NAME, "0");
let error = failure::err_msg("Hello!");
let resp: Error = error.into();
assert_eq!(format!("{:?}", resp), "Compat { error: ErrorMessage { msg: \"Hello!\" } }\n\n");
assert_eq!(
format!("{:?}", resp),
"Compat { error: ErrorMessage { msg: \"Hello!\" } }\n\n"
);
match old_tb {
Ok(x) => env::set_var(NAME, x),
_ => env::remove_var(NAME),

View File

@ -1,19 +1,19 @@
use std::str;
use std::ops::{Deref, DerefMut};
use std::str;
use mime::Mime;
use bytes::Bytes;
use serde_urlencoded;
use serde::de::{self, DeserializeOwned};
use futures::future::{Future, FutureResult, result};
use encoding::all::UTF_8;
use encoding::types::{Encoding, DecoderTrap};
use encoding::types::{DecoderTrap, Encoding};
use futures::future::{result, Future, FutureResult};
use mime::Mime;
use serde::de::{self, DeserializeOwned};
use serde_urlencoded;
use de::PathDeserializer;
use error::{Error, ErrorBadRequest};
use handler::{Either, FromRequest};
use httprequest::HttpRequest;
use httpmessage::{HttpMessage, MessageBody, UrlEncoded};
use de::PathDeserializer;
use httprequest::HttpRequest;
/// Extract typed information from the request's path.
///
@ -39,8 +39,8 @@ use de::PathDeserializer;
/// }
/// ```
///
/// It is possible to extract path information to a specific type that implements
/// `Deserialize` trait from *serde*.
/// It is possible to extract path information to a specific type that
/// implements `Deserialize` trait from *serde*.
///
/// ```rust
/// # extern crate bytes;
@ -65,12 +65,11 @@ use de::PathDeserializer;
/// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor
/// }
/// ```
pub struct Path<T>{
inner: T
pub struct Path<T> {
inner: T,
}
impl<T> AsRef<T> for Path<T> {
fn as_ref(&self) -> &T {
&self.inner
}
@ -98,7 +97,9 @@ impl<T> Path<T> {
}
impl<T, S> FromRequest<S> for Path<T>
where T: DeserializeOwned, S: 'static
where
T: DeserializeOwned,
S: 'static,
{
type Config = ();
type Result = FutureResult<Self, Error>;
@ -106,9 +107,11 @@ impl<T, S> FromRequest<S> for Path<T>
#[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result {
let req = req.clone();
result(de::Deserialize::deserialize(PathDeserializer::new(&req))
.map_err(|e| e.into())
.map(|inner| Path{inner}))
result(
de::Deserialize::deserialize(PathDeserializer::new(&req))
.map_err(|e| e.into())
.map(|inner| Path { inner }),
)
}
}
@ -164,7 +167,9 @@ impl<T> Query<T> {
}
impl<T, S> FromRequest<S> for Query<T>
where T: de::DeserializeOwned, S: 'static
where
T: de::DeserializeOwned,
S: 'static,
{
type Config = ();
type Result = FutureResult<Self, Error>;
@ -172,24 +177,26 @@ impl<T, S> FromRequest<S> for Query<T>
#[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result {
let req = req.clone();
result(serde_urlencoded::from_str::<T>(req.query_string())
.map_err(|e| e.into())
.map(Query))
result(
serde_urlencoded::from_str::<T>(req.query_string())
.map_err(|e| e.into())
.map(Query),
)
}
}
/// Extract typed information from the request's body.
///
/// To extract typed information from request's body, the type `T` must implement the
/// `Deserialize` trait from *serde*.
/// To extract typed information from request's body, the type `T` must
/// implement the `Deserialize` trait from *serde*.
///
/// [**FormConfig**](dev/struct.FormConfig.html) allows to configure extraction
/// process.
///
/// ## Example
///
/// It is possible to extract path information to a specific type that implements
/// `Deserialize` trait from *serde*.
/// It is possible to extract path information to a specific type that
/// implements `Deserialize` trait from *serde*.
///
/// ```rust
/// # extern crate actix_web;
@ -233,17 +240,21 @@ impl<T> DerefMut for Form<T> {
}
impl<T, S> FromRequest<S> for Form<T>
where T: DeserializeOwned + 'static, S: 'static
where
T: DeserializeOwned + 'static,
S: 'static,
{
type Config = FormConfig;
type Result = Box<Future<Item=Self, Error=Error>>;
type Result = Box<Future<Item = Self, Error = Error>>;
#[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
Box::new(UrlEncoded::new(req.clone())
.limit(cfg.limit)
.from_err()
.map(Form))
Box::new(
UrlEncoded::new(req.clone())
.limit(cfg.limit)
.from_err()
.map(Form),
)
}
}
@ -279,7 +290,6 @@ pub struct FormConfig {
}
impl FormConfig {
/// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit;
@ -289,7 +299,7 @@ impl FormConfig {
impl Default for FormConfig {
fn default() -> Self {
FormConfig{limit: 262_144}
FormConfig { limit: 262_144 }
}
}
@ -297,8 +307,8 @@ impl Default for FormConfig {
///
/// Loads request's payload and construct Bytes instance.
///
/// [**PayloadConfig**](dev/struct.PayloadConfig.html) allows to configure extraction
/// process.
/// [**PayloadConfig**](dev/struct.PayloadConfig.html) allows to configure
/// extraction process.
///
/// ## Example
///
@ -313,11 +323,10 @@ impl Default for FormConfig {
/// }
/// # fn main() {}
/// ```
impl<S: 'static> FromRequest<S> for Bytes
{
impl<S: 'static> FromRequest<S> for Bytes {
type Config = PayloadConfig;
type Result = Either<FutureResult<Self, Error>,
Box<Future<Item=Self, Error=Error>>>;
type Result =
Either<FutureResult<Self, Error>, Box<Future<Item = Self, Error = Error>>>;
#[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -326,9 +335,11 @@ impl<S: 'static> FromRequest<S> for Bytes
return Either::A(result(Err(e)));
}
Either::B(Box::new(MessageBody::new(req.clone())
.limit(cfg.limit)
.from_err()))
Either::B(Box::new(
MessageBody::new(req.clone())
.limit(cfg.limit)
.from_err(),
))
}
}
@ -351,11 +362,10 @@ impl<S: 'static> FromRequest<S> for Bytes
/// }
/// # fn main() {}
/// ```
impl<S: 'static> FromRequest<S> for String
{
impl<S: 'static> FromRequest<S> for String {
type Config = PayloadConfig;
type Result = Either<FutureResult<String, Error>,
Box<Future<Item=String, Error=Error>>>;
type Result =
Either<FutureResult<String, Error>, Box<Future<Item = String, Error = Error>>>;
#[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -366,8 +376,11 @@ impl<S: 'static> FromRequest<S> for String
// check charset
let encoding = match req.encoding() {
Err(_) => return Either::A(
result(Err(ErrorBadRequest("Unknown request charset")))),
Err(_) => {
return Either::A(result(Err(ErrorBadRequest(
"Unknown request charset",
))))
}
Ok(encoding) => encoding,
};
@ -379,13 +392,15 @@ impl<S: 'static> FromRequest<S> for String
let enc: *const Encoding = encoding as *const Encoding;
if enc == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding.decode(&body, DecoderTrap::Strict)
.map_err(|_| ErrorBadRequest("Can not decode body"))?)
Ok(encoding
.decode(&body, DecoderTrap::Strict)
.map_err(|_| ErrorBadRequest("Can not decode body"))?)
}
})))
}),
))
}
}
@ -396,14 +411,14 @@ pub struct PayloadConfig {
}
impl PayloadConfig {
/// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit;
self
}
/// Set required mime-type of the request. By default mime type is not enforced.
/// Set required mime-type of the request. By default mime type is not
/// enforced.
pub fn mimetype(&mut self, mt: Mime) -> &mut Self {
self.mimetype = Some(mt);
self
@ -417,13 +432,13 @@ impl PayloadConfig {
if mt != req_mt {
return Err(ErrorBadRequest("Unexpected Content-Type"));
}
},
}
Ok(None) => {
return Err(ErrorBadRequest("Content-Type is expected"));
},
}
Err(err) => {
return Err(err.into());
},
}
}
}
Ok(())
@ -432,21 +447,24 @@ impl PayloadConfig {
impl Default for PayloadConfig {
fn default() -> Self {
PayloadConfig{limit: 262_144, mimetype: None}
PayloadConfig {
limit: 262_144,
mimetype: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use mime;
use bytes::Bytes;
use futures::{Async, Future};
use http::header;
use router::{Router, Resource};
use mime;
use resource::ResourceHandler;
use test::TestRequest;
use router::{Resource, Router};
use server::ServerSettings;
use test::TestRequest;
#[derive(Deserialize, Debug, PartialEq)]
struct Info {
@ -457,12 +475,13 @@ mod tests {
fn test_bytes() {
let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
match Bytes::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => {
assert_eq!(s, Bytes::from_static(b"hello=world"));
},
}
_ => unreachable!(),
}
}
@ -471,12 +490,13 @@ mod tests {
fn test_string() {
let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
match String::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => {
assert_eq!(s, "hello=world");
},
}
_ => unreachable!(),
}
}
@ -484,17 +504,19 @@ mod tests {
#[test]
fn test_form() {
let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::CONTENT_LENGTH, "11")
header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "11")
.finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let mut cfg = FormConfig::default();
cfg.limit(4096);
match Form::<Info>::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => {
assert_eq!(s.hello, "world");
},
}
_ => unreachable!(),
}
}
@ -507,10 +529,13 @@ mod tests {
assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded").finish();
header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).finish();
assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish();
let req =
TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish();
assert!(cfg.check_mimetype(&req).is_ok());
}
@ -538,30 +563,39 @@ mod tests {
let mut resource = ResourceHandler::<()>::default();
resource.name("index");
let mut routes = Vec::new();
routes.push((Resource::new("index", "/{key}/{value}/"), Some(resource)));
routes.push((
Resource::new("index", "/{key}/{value}/"),
Some(resource),
));
let (router, _) = Router::new("", ServerSettings::default(), routes);
assert!(router.recognize(&mut req).is_some());
match Path::<MyStruct>::from_request(&req, &()).poll().unwrap() {
match Path::<MyStruct>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => {
assert_eq!(s.key, "name");
assert_eq!(s.value, "user1");
},
}
_ => unreachable!(),
}
match Path::<(String, String)>::from_request(&req, &()).poll().unwrap() {
match Path::<(String, String)>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => {
assert_eq!(s.0, "name");
assert_eq!(s.1, "user1");
},
}
_ => unreachable!(),
}
match Query::<Id>::from_request(&req, &()).poll().unwrap() {
Async::Ready(s) => {
assert_eq!(s.id, "test");
},
}
_ => unreachable!(),
}
@ -572,22 +606,31 @@ mod tests {
Async::Ready(s) => {
assert_eq!(s.as_ref().key, "name");
assert_eq!(s.value, 32);
},
}
_ => unreachable!(),
}
match Path::<(String, u8)>::from_request(&req, &()).poll().unwrap() {
match Path::<(String, u8)>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => {
assert_eq!(s.0, "name");
assert_eq!(s.1, 32);
},
}
_ => unreachable!(),
}
match Path::<Vec<String>>::from_request(&req, &()).poll().unwrap() {
match Path::<Vec<String>>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => {
assert_eq!(s.into_inner(), vec!["name".to_owned(), "32".to_owned()]);
},
assert_eq!(
s.into_inner(),
vec!["name".to_owned(), "32".to_owned()]
);
}
_ => unreachable!(),
}
}
@ -606,7 +649,7 @@ mod tests {
match Path::<i8>::from_request(&req, &()).poll().unwrap() {
Async::Ready(s) => {
assert_eq!(s.into_inner(), 32);
},
}
_ => unreachable!(),
}
}

318
src/fs.rs
View File

@ -1,30 +1,30 @@
//! Static files support
use std::{io, cmp};
use std::io::{Read, Seek};
use std::fmt::Write;
use std::fs::{File, DirEntry, Metadata};
use std::path::{Path, PathBuf};
use std::fs::{DirEntry, File, Metadata};
use std::io::{Read, Seek};
use std::ops::{Deref, DerefMut};
use std::time::{SystemTime, UNIX_EPOCH};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use std::{cmp, io};
#[cfg(unix)]
use std::os::unix::fs::MetadataExt;
use bytes::{Bytes, BytesMut, BufMut};
use futures::{Async, Poll, Future, Stream};
use futures_cpupool::{CpuPool, CpuFuture};
use bytes::{BufMut, Bytes, BytesMut};
use futures::{Async, Future, Poll, Stream};
use futures_cpupool::{CpuFuture, CpuPool};
use mime_guess::get_mime_type;
use percent_encoding::percent_decode;
use header;
use error::Error;
use param::FromParam;
use handler::{Handler, RouteHandler, WrapHandler, Responder, Reply};
use handler::{Handler, Reply, Responder, RouteHandler, WrapHandler};
use header;
use http::{Method, StatusCode};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use param::FromParam;
/// A file with an associated name; responds with the Content-Type based on the
/// file extension.
@ -55,9 +55,15 @@ impl NamedFile {
let path = path.as_ref().to_path_buf();
let modified = md.modified().ok();
let cpu_pool = None;
Ok(NamedFile{path, file, md, modified, cpu_pool,
only_get: false,
status_code: StatusCode::OK})
Ok(NamedFile {
path,
file,
md,
modified,
cpu_pool,
only_get: false,
status_code: StatusCode::OK,
})
}
/// Allow only GET and HEAD methods
@ -110,17 +116,25 @@ impl NamedFile {
self.modified.as_ref().map(|mtime| {
let ino = {
#[cfg(unix)]
{ self.md.ino() }
{
self.md.ino()
}
#[cfg(not(unix))]
{ 0 }
{
0
}
};
let dur = mtime.duration_since(UNIX_EPOCH)
let dur = mtime
.duration_since(UNIX_EPOCH)
.expect("modification time must be after epoch");
header::EntityTag::strong(
format!("{:x}:{:x}:{:x}:{:x}",
ino, self.md.len(), dur.as_secs(),
dur.subsec_nanos()))
header::EntityTag::strong(format!(
"{:x}:{:x}:{:x}:{:x}",
ino,
self.md.len(),
dur.as_secs(),
dur.subsec_nanos()
))
})
}
@ -178,7 +192,6 @@ fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
}
}
impl Responder for NamedFile {
type Item = HttpResponse;
type Error = io::Error;
@ -187,23 +200,27 @@ impl Responder for NamedFile {
if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code);
resp.if_some(self.path().extension(), |ext, resp| {
resp.set(header::ContentType(get_mime_type(&ext.to_string_lossy())));
resp.set(header::ContentType(get_mime_type(
&ext.to_string_lossy(),
)));
});
let reader = ChunkedReadFile {
size: self.md.len(),
offset: 0,
cpu_pool: self.cpu_pool.unwrap_or_else(|| req.cpu_pool().clone()),
cpu_pool: self.cpu_pool
.unwrap_or_else(|| req.cpu_pool().clone()),
file: Some(self.file),
fut: None,
};
return Ok(resp.streaming(reader))
return Ok(resp.streaming(reader));
}
if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD {
if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD
{
return Ok(HttpResponse::MethodNotAllowed()
.header(header::CONTENT_TYPE, "text/plain")
.header(header::ALLOW, "GET, HEAD")
.body("This resource only supports GET and HEAD."))
.header(header::CONTENT_TYPE, "text/plain")
.header(header::ALLOW, "GET, HEAD")
.body("This resource only supports GET and HEAD."));
}
let etag = self.etag();
@ -233,17 +250,21 @@ impl Responder for NamedFile {
let mut resp = HttpResponse::build(self.status_code);
resp
.if_some(self.path().extension(), |ext, resp| {
resp.set(header::ContentType(get_mime_type(&ext.to_string_lossy())));
resp.if_some(self.path().extension(), |ext, resp| {
resp.set(header::ContentType(get_mime_type(
&ext.to_string_lossy(),
)));
}).if_some(last_modified, |lm, resp| {
resp.set(header::LastModified(lm));
})
.if_some(last_modified, |lm, resp| {resp.set(header::LastModified(lm));})
.if_some(etag, |etag, resp| {resp.set(header::ETag(etag));});
.if_some(etag, |etag, resp| {
resp.set(header::ETag(etag));
});
if precondition_failed {
return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish())
return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish());
} else if not_modified {
return Ok(resp.status(StatusCode::NOT_MODIFIED).finish())
return Ok(resp.status(StatusCode::NOT_MODIFIED).finish());
}
if *req.method() == Method::HEAD {
@ -252,7 +273,8 @@ impl Responder for NamedFile {
let reader = ChunkedReadFile {
size: self.md.len(),
offset: 0,
cpu_pool: self.cpu_pool.unwrap_or_else(|| req.cpu_pool().clone()),
cpu_pool: self.cpu_pool
.unwrap_or_else(|| req.cpu_pool().clone()),
file: Some(self.file),
fut: None,
};
@ -273,7 +295,7 @@ pub struct ChunkedReadFile {
impl Stream for ChunkedReadFile {
type Item = Bytes;
type Error= Error;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
if self.fut.is_some() {
@ -283,7 +305,7 @@ impl Stream for ChunkedReadFile {
self.file = Some(file);
self.offset += bytes.len() as u64;
Ok(Async::Ready(Some(bytes)))
},
}
Async::NotReady => Ok(Async::NotReady),
};
}
@ -299,11 +321,11 @@ impl Stream for ChunkedReadFile {
let max_bytes = cmp::min(size.saturating_sub(offset), 65_536) as usize;
let mut buf = BytesMut::with_capacity(max_bytes);
file.seek(io::SeekFrom::Start(offset))?;
let nbytes = file.read(unsafe{buf.bytes_mut()})?;
let nbytes = file.read(unsafe { buf.bytes_mut() })?;
if nbytes == 0 {
return Err(io::ErrorKind::UnexpectedEof.into())
return Err(io::ErrorKind::UnexpectedEof.into());
}
unsafe{buf.advance_mut(nbytes)};
unsafe { buf.advance_mut(nbytes) };
Ok((file, buf.freeze()))
}));
self.poll()
@ -313,9 +335,9 @@ impl Stream for ChunkedReadFile {
/// A directory; responds with the generated directory listing.
#[derive(Debug)]
pub struct Directory{
pub struct Directory {
base: PathBuf,
path: PathBuf
path: PathBuf,
}
impl Directory {
@ -327,12 +349,12 @@ impl Directory {
if let Ok(ref entry) = *entry {
if let Some(name) = entry.file_name().to_str() {
if name.starts_with('.') {
return false
return false;
}
}
if let Ok(ref md) = entry.metadata() {
let ft = md.file_type();
return ft.is_dir() || ft.is_file() || ft.is_symlink()
return ft.is_dir() || ft.is_file() || ft.is_symlink();
}
}
false
@ -353,7 +375,7 @@ impl Responder for Directory {
let entry = entry.unwrap();
let p = match entry.path().strip_prefix(&self.path) {
Ok(p) => base.join(p),
Err(_) => continue
Err(_) => continue,
};
// show file url as relative to static path
let file_url = format!("{}", p.to_string_lossy());
@ -361,27 +383,38 @@ impl Responder for Directory {
// if file is a directory, add '/' to the end of the name
if let Ok(metadata) = entry.metadata() {
if metadata.is_dir() {
let _ = write!(body, "<li><a href=\"{}\">{}/</a></li>",
file_url, entry.file_name().to_string_lossy());
let _ = write!(
body,
"<li><a href=\"{}\">{}/</a></li>",
file_url,
entry.file_name().to_string_lossy()
);
} else {
let _ = write!(body, "<li><a href=\"{}\">{}</a></li>",
file_url, entry.file_name().to_string_lossy());
let _ = write!(
body,
"<li><a href=\"{}\">{}</a></li>",
file_url,
entry.file_name().to_string_lossy()
);
}
} else {
continue
continue;
}
}
}
let html = format!("<html>\
<head><title>{}</title></head>\
<body><h1>{}</h1>\
<ul>\
{}\
</ul></body>\n</html>", index_of, index_of, body);
let html = format!(
"<html>\
<head><title>{}</title></head>\
<body><h1>{}</h1>\
<ul>\
{}\
</ul></body>\n</html>",
index_of, index_of, body
);
Ok(HttpResponse::Ok()
.content_type("text/html; charset=utf-8")
.body(html))
.content_type("text/html; charset=utf-8")
.body(html))
}
}
@ -411,12 +444,11 @@ pub struct StaticFiles<S> {
_follow_symlinks: bool,
}
lazy_static!{
lazy_static! {
static ref DEFAULT_CPUPOOL: Mutex<CpuPool> = Mutex::new(CpuPool::new(20));
}
impl<S: 'static> StaticFiles<S> {
/// Create new `StaticFiles` instance for specified base directory.
pub fn new<T: Into<PathBuf>>(dir: T) -> StaticFiles<S> {
let dir = dir.into();
@ -429,7 +461,7 @@ impl<S: 'static> StaticFiles<S> {
warn!("Is not directory `{:?}`", dir);
(dir, false)
}
},
}
Err(err) => {
warn!("Static files directory `{:?}` error: {}", dir, err);
(dir, false)
@ -437,9 +469,7 @@ impl<S: 'static> StaticFiles<S> {
};
// use default CpuPool
let pool = {
DEFAULT_CPUPOOL.lock().unwrap().clone()
};
let pool = { DEFAULT_CPUPOOL.lock().unwrap().clone() };
StaticFiles {
directory: dir,
@ -447,8 +477,9 @@ impl<S: 'static> StaticFiles<S> {
index: None,
show_index: false,
cpu_pool: pool,
default: Box::new(WrapHandler::new(
|_| HttpResponse::new(StatusCode::NOT_FOUND))),
default: Box::new(WrapHandler::new(|_| {
HttpResponse::new(StatusCode::NOT_FOUND)
})),
_chunk_size: 0,
_follow_symlinks: false,
}
@ -485,12 +516,13 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
if !self.accessible {
Ok(self.default.handle(req))
} else {
let relpath = match req.match_info().get("tail").map(
|tail| percent_decode(tail.as_bytes()).decode_utf8().unwrap())
let relpath = match req.match_info()
.get("tail")
.map(|tail| percent_decode(tail.as_bytes()).decode_utf8().unwrap())
.map(|tail| PathBuf::from_param(tail.as_ref()))
{
Some(Ok(path)) => path,
_ => return Ok(self.default.handle(req))
_ => return Ok(self.default.handle(req)),
};
// full filepath
@ -499,7 +531,8 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
if path.is_dir() {
if let Some(ref redir_index) = self.index {
// TODO: Don't redirect, just return the index content.
// TODO: It'd be nice if there were a good usable URL manipulation library
// TODO: It'd be nice if there were a good usable URL manipulation
// library
let mut new_path: String = req.path().to_owned();
for el in relpath.iter() {
new_path.push_str(&el.to_string_lossy());
@ -516,14 +549,15 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
} else if self.show_index {
Directory::new(self.directory.clone(), path)
.respond_to(req.drop_state())?
.respond_to(req.drop_state())
.respond_to(req.drop_state())
} else {
Ok(self.default.handle(req))
}
} else {
NamedFile::open(path)?.set_cpu_pool(self.cpu_pool.clone())
NamedFile::open(path)?
.set_cpu_pool(self.cpu_pool.clone())
.respond_to(req.drop_state())?
.respond_to(req.drop_state())
.respond_to(req.drop_state())
}
}
}
@ -533,33 +567,49 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
mod tests {
use super::*;
use application::App;
use test::{self, TestRequest};
use http::{header, Method, StatusCode};
use test::{self, TestRequest};
#[test]
fn test_named_file() {
assert!(NamedFile::open("test--").is_err());
let mut file = NamedFile::open("Cargo.toml").unwrap()
let mut file = NamedFile::open("Cargo.toml")
.unwrap()
.set_cpu_pool(CpuPool::new(1));
{ file.file();
let _f: &File = &file; }
{ let _f: &mut File = &mut file; }
{
file.file();
let _f: &File = &file;
}
{
let _f: &mut File = &mut file;
}
let resp = file.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml")
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/x-toml"
)
}
#[test]
fn test_named_file_status_code() {
let mut file = NamedFile::open("Cargo.toml").unwrap()
let mut file = NamedFile::open("Cargo.toml")
.unwrap()
.set_status_code(StatusCode::NOT_FOUND)
.set_cpu_pool(CpuPool::new(1));
{ file.file();
let _f: &File = &file; }
{ let _f: &mut File = &mut file; }
{
file.file();
let _f: &File = &file;
}
{
let _f: &mut File = &mut file;
}
let resp = file.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml");
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/x-toml"
);
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}
@ -584,13 +634,17 @@ mod tests {
fn test_static_files() {
let mut st = StaticFiles::new(".").show_files_listing();
st.accessible = false;
let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(HttpRequest::default())
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
st.accessible = true;
st.show_index = false;
let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(HttpRequest::default())
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
@ -598,9 +652,14 @@ mod tests {
req.match_info_mut().add("tail", "");
st.show_index = true;
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/html; charset=utf-8");
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/html; charset=utf-8"
);
assert!(resp.body().is_binary());
assert!(format!("{:?}", resp.body()).contains("README.md"));
}
@ -611,18 +670,28 @@ mod tests {
let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "guide");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html");
assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/guide/index.html"
);
let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "guide/");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html");
assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/guide/index.html"
);
}
#[test]
@ -631,58 +700,87 @@ mod tests {
let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "tools/wsload");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap();
let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/tools/wsload/Cargo.toml");
assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/tools/wsload/Cargo.toml"
);
}
#[test]
fn integration_redirect_to_index_with_prefix() {
let mut srv = test::TestServer::with_factory(
|| App::new()
let mut srv = test::TestServer::with_factory(|| {
App::new()
.prefix("public")
.handler("/", StaticFiles::new(".").index_file("Cargo.toml")));
.handler("/", StaticFiles::new(".").index_file("Cargo.toml"))
});
let request = srv.get().uri(srv.url("/public")).finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap();
let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/public/Cargo.toml");
let request = srv.get().uri(srv.url("/public/")).finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap();
let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/public/Cargo.toml");
}
#[test]
fn integration_redirect_to_index() {
let mut srv = test::TestServer::with_factory(
|| App::new()
.handler("test", StaticFiles::new(".").index_file("Cargo.toml")));
let mut srv = test::TestServer::with_factory(|| {
App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
});
let request = srv.get().uri(srv.url("/test")).finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap();
let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/test/Cargo.toml");
let request = srv.get().uri(srv.url("/test/")).finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap();
let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/test/Cargo.toml");
}
#[test]
fn integration_percent_encoded() {
let mut srv = test::TestServer::with_factory(
|| App::new()
.handler("test", StaticFiles::new(".").index_file("Cargo.toml")));
let mut srv = test::TestServer::with_factory(|| {
App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
});
let request = srv.get().uri(srv.url("/test/%43argo.toml")).finish().unwrap();
let request = srv.get()
.uri(srv.url("/test/%43argo.toml"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::OK);
}

View File

@ -1,7 +1,7 @@
use std::ops::Deref;
use std::marker::PhantomData;
use futures::Poll;
use futures::future::{Future, FutureResult, ok, err};
use futures::future::{err, ok, Future, FutureResult};
use std::marker::PhantomData;
use std::ops::Deref;
use error::Error;
use httprequest::HttpRequest;
@ -10,7 +10,6 @@ use httpresponse::HttpResponse;
/// Trait defines object that could be registered as route handler
#[allow(unused_variables)]
pub trait Handler<S>: 'static {
/// The type of value that handler will return.
type Result: Responder;
@ -35,13 +34,15 @@ pub trait Responder {
/// Trait implemented by types that can be extracted from request.
///
/// Types that implement this trait can be used with `Route::with()` method.
pub trait FromRequest<S>: Sized where S: 'static
pub trait FromRequest<S>: Sized
where
S: 'static,
{
/// Configuration for conversion process
type Config: Default;
/// Future that resolves to a Self
type Result: Future<Item=Self, Error=Error>;
type Result: Future<Item = Self, Error = Error>;
/// Convert request to a Self
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result;
@ -83,7 +84,9 @@ pub enum Either<A, B> {
}
impl<A, B> Responder for Either<A, B>
where A: Responder, B: Responder
where
A: Responder,
B: Responder,
{
type Item = Reply;
type Error = Error;
@ -103,8 +106,9 @@ impl<A, B> Responder for Either<A, B>
}
impl<A, B, I, E> Future for Either<A, B>
where A: Future<Item=I, Error=E>,
B: Future<Item=I, Error=E>,
where
A: Future<Item = I, Error = E>,
B: Future<Item = I, Error = E>,
{
type Item = I;
type Error = E;
@ -146,23 +150,25 @@ impl<A, B, I, E> Future for Either<A, B>
/// # fn main() {}
/// ```
pub trait AsyncResponder<I, E>: Sized {
fn responder(self) -> Box<Future<Item=I, Error=E>>;
fn responder(self) -> Box<Future<Item = I, Error = E>>;
}
impl<F, I, E> AsyncResponder<I, E> for F
where F: Future<Item=I, Error=E> + 'static,
I: Responder + 'static,
E: Into<Error> + 'static,
where
F: Future<Item = I, Error = E> + 'static,
I: Responder + 'static,
E: Into<Error> + 'static,
{
fn responder(self) -> Box<Future<Item=I, Error=E>> {
fn responder(self) -> Box<Future<Item = I, Error = E>> {
Box::new(self)
}
}
/// Handler<S> for Fn()
impl<F, R, S> Handler<S> for F
where F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static
where
F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{
type Result = R;
@ -176,15 +182,15 @@ pub struct Reply(ReplyItem);
pub(crate) enum ReplyItem {
Message(HttpResponse),
Future(Box<Future<Item=HttpResponse, Error=Error>>),
Future(Box<Future<Item = HttpResponse, Error = Error>>),
}
impl Reply {
/// Create async response
#[inline]
pub fn async<F>(fut: F) -> Reply
where F: Future<Item=HttpResponse, Error=Error> + 'static
where
F: Future<Item = HttpResponse, Error = Error> + 'static,
{
Reply(ReplyItem::Future(Box::new(fut)))
}
@ -229,15 +235,13 @@ impl Responder for HttpResponse {
}
impl From<HttpResponse> for Reply {
#[inline]
fn from(resp: HttpResponse) -> Reply {
Reply(ReplyItem::Message(resp))
}
}
impl<T: Responder, E: Into<Error>> Responder for Result<T, E>
{
impl<T: Responder, E: Into<Error>> Responder for Result<T, E> {
type Item = <T as Responder>::Item;
type Error = Error;
@ -272,19 +276,20 @@ impl<E: Into<Error>> From<Result<HttpResponse, E>> for Reply {
}
}
impl From<Box<Future<Item=HttpResponse, Error=Error>>> for Reply {
impl From<Box<Future<Item = HttpResponse, Error = Error>>> for Reply {
#[inline]
fn from(fut: Box<Future<Item=HttpResponse, Error=Error>>) -> Reply {
fn from(fut: Box<Future<Item = HttpResponse, Error = Error>>) -> Reply {
Reply(ReplyItem::Future(fut))
}
}
/// Convenience type alias
pub type FutureResponse<I, E=Error> = Box<Future<Item=I, Error=E>>;
pub type FutureResponse<I, E = Error> = Box<Future<Item = I, Error = E>>;
impl<I, E> Responder for Box<Future<Item=I, Error=E>>
where I: Responder + 'static,
E: Into<Error> + 'static
impl<I, E> Responder for Box<Future<Item = I, Error = E>>
where
I: Responder + 'static,
E: Into<Error> + 'static,
{
type Item = Reply;
type Error = Error;
@ -292,14 +297,12 @@ impl<I, E> Responder for Box<Future<Item=I, Error=E>>
#[inline]
fn respond_to(self, req: HttpRequest) -> Result<Reply, Error> {
let fut = self.map_err(|e| e.into())
.then(move |r| {
match r.respond_to(req) {
Ok(reply) => match reply.into().0 {
ReplyItem::Message(resp) => ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => err(e),
}
.then(move |r| match r.respond_to(req) {
Ok(reply) => match reply.into().0 {
ReplyItem::Message(resp) => ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => err(e),
});
Ok(Reply::async(fut))
}
@ -311,30 +314,35 @@ pub(crate) trait RouteHandler<S>: 'static {
}
/// Route handler wrapper for Handler
pub(crate)
struct WrapHandler<S, H, R>
where H: Handler<S, Result=R>,
R: Responder,
S: 'static,
pub(crate) struct WrapHandler<S, H, R>
where
H: Handler<S, Result = R>,
R: Responder,
S: 'static,
{
h: H,
s: PhantomData<S>,
}
impl<S, H, R> WrapHandler<S, H, R>
where H: Handler<S, Result=R>,
R: Responder,
S: 'static,
where
H: Handler<S, Result = R>,
R: Responder,
S: 'static,
{
pub fn new(h: H) -> Self {
WrapHandler{h, s: PhantomData}
WrapHandler {
h,
s: PhantomData,
}
}
}
impl<S, H, R> RouteHandler<S> for WrapHandler<S, H, R>
where H: Handler<S, Result=R>,
R: Responder + 'static,
S: 'static,
where
H: Handler<S, Result = R>,
R: Responder + 'static,
S: 'static,
{
fn handle(&mut self, req: HttpRequest<S>) -> Reply {
let req2 = req.drop_state();
@ -346,50 +354,53 @@ impl<S, H, R> RouteHandler<S> for WrapHandler<S, H, R>
}
/// Async route handler
pub(crate)
struct AsyncHandler<S, H, F, R, E>
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
pub(crate) struct AsyncHandler<S, H, F, R, E>
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
{
h: Box<H>,
s: PhantomData<S>,
}
impl<S, H, F, R, E> AsyncHandler<S, H, F, R, E>
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
{
pub fn new(h: H) -> Self {
AsyncHandler{h: Box::new(h), s: PhantomData}
AsyncHandler {
h: Box::new(h),
s: PhantomData,
}
}
}
impl<S, H, F, R, E> RouteHandler<S> for AsyncHandler<S, H, F, R, E>
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
S: 'static,
{
fn handle(&mut self, req: HttpRequest<S>) -> Reply {
let req2 = req.drop_state();
let fut = (self.h)(req)
.map_err(|e| e.into())
.then(move |r| {
match r.respond_to(req2) {
Ok(reply) => match reply.into().0 {
ReplyItem::Message(resp) => ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => err(e),
}
});
let fut = (self.h)(req).map_err(|e| e.into()).then(move |r| {
match r.respond_to(req2) {
Ok(reply) => match reply.into().0 {
ReplyItem::Message(resp) => ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => err(e),
}
});
Reply::async(fut)
}
}
@ -426,7 +437,7 @@ impl<S, H, F, R, E> RouteHandler<S> for AsyncHandler<S, H, F, R, E>
/// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor
/// }
/// ```
pub struct State<S> (HttpRequest<S>);
pub struct State<S>(HttpRequest<S>);
impl<S> Deref for State<S> {
type Target = S;
@ -436,8 +447,7 @@ impl<S> Deref for State<S> {
}
}
impl<S: 'static> FromRequest<S> for State<S>
{
impl<S: 'static> FromRequest<S> for State<S> {
type Config = ();
type Result = FutureResult<Self, Error>;

View File

@ -1,6 +1,6 @@
use mime::{self, Mime};
use header::{QualityItem, qitem};
use header::{qitem, QualityItem};
use http::header as http;
use mime::{self, Mime};
header! {
/// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2)

View File

@ -1,4 +1,4 @@
use header::{ACCEPT_CHARSET, Charset, QualityItem};
use header::{Charset, QualityItem, ACCEPT_CHARSET};
header! {
/// `Accept-Charset` header, defined in

View File

@ -1,5 +1,5 @@
use header::{QualityItem, ACCEPT_LANGUAGE};
use language_tags::LanguageTag;
use header::{ACCEPT_LANGUAGE, QualityItem};
header! {
/// `Accept-Language` header, defined in

View File

@ -1,8 +1,8 @@
use header::{Header, IntoHeaderValue, Writer};
use header::{fmt_comma_delimited, from_comma_delimited};
use http::header;
use std::fmt::{self, Write};
use std::str::FromStr;
use http::header;
use header::{Header, IntoHeaderValue, Writer};
use header::{from_comma_delimited, fmt_comma_delimited};
/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2)
///
@ -30,9 +30,7 @@ use header::{from_comma_delimited, fmt_comma_delimited};
/// use actix_web::http::header::{CacheControl, CacheDirective};
///
/// let mut builder = HttpResponse::Ok();
/// builder.set(
/// CacheControl(vec![CacheDirective::MaxAge(86400u32)])
/// );
/// builder.set(CacheControl(vec![CacheDirective::MaxAge(86400u32)]));
/// ```
///
/// ```rust
@ -40,15 +38,12 @@ use header::{from_comma_delimited, fmt_comma_delimited};
/// use actix_web::http::header::{CacheControl, CacheDirective};
///
/// let mut builder = HttpResponse::Ok();
/// builder.set(
/// CacheControl(vec![
/// CacheDirective::NoCache,
/// CacheDirective::Private,
/// CacheDirective::MaxAge(360u32),
/// CacheDirective::Extension("foo".to_owned(),
/// Some("bar".to_owned())),
/// ])
/// );
/// builder.set(CacheControl(vec![
/// CacheDirective::NoCache,
/// CacheDirective::Private,
/// CacheDirective::MaxAge(360u32),
/// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())),
/// ]));
/// ```
#[derive(PartialEq, Clone, Debug)]
pub struct CacheControl(pub Vec<CacheDirective>);
@ -63,7 +58,8 @@ impl Header for CacheControl {
#[inline]
fn parse<T>(msg: &T) -> Result<Self, ::error::ParseError>
where T: ::HttpMessage
where
T: ::HttpMessage,
{
let directives = from_comma_delimited(msg.headers().get_all(Self::name()))?;
if !directives.is_empty() {
@ -123,32 +119,36 @@ pub enum CacheDirective {
SMaxAge(u32),
/// Extension directives. Optionally include an argument.
Extension(String, Option<String>)
Extension(String, Option<String>),
}
impl fmt::Display for CacheDirective {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::CacheDirective::*;
fmt::Display::fmt(match *self {
NoCache => "no-cache",
NoStore => "no-store",
NoTransform => "no-transform",
OnlyIfCached => "only-if-cached",
fmt::Display::fmt(
match *self {
NoCache => "no-cache",
NoStore => "no-store",
NoTransform => "no-transform",
OnlyIfCached => "only-if-cached",
MaxAge(secs) => return write!(f, "max-age={}", secs),
MaxStale(secs) => return write!(f, "max-stale={}", secs),
MinFresh(secs) => return write!(f, "min-fresh={}", secs),
MaxAge(secs) => return write!(f, "max-age={}", secs),
MaxStale(secs) => return write!(f, "max-stale={}", secs),
MinFresh(secs) => return write!(f, "min-fresh={}", secs),
MustRevalidate => "must-revalidate",
Public => "public",
Private => "private",
ProxyRevalidate => "proxy-revalidate",
SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
MustRevalidate => "must-revalidate",
Public => "public",
Private => "private",
ProxyRevalidate => "proxy-revalidate",
SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
Extension(ref name, None) => &name[..],
Extension(ref name, Some(ref arg)) => return write!(f, "{}={}", name, arg),
}, f)
Extension(ref name, None) => &name[..],
Extension(ref name, Some(ref arg)) => {
return write!(f, "{}={}", name, arg)
}
},
f,
)
}
}
@ -167,16 +167,20 @@ impl FromStr for CacheDirective {
"proxy-revalidate" => Ok(ProxyRevalidate),
"" => Err(None),
_ => match s.find('=') {
Some(idx) if idx+1 < s.len() => match (&s[..idx], (&s[idx+1..]).trim_matches('"')) {
("max-age" , secs) => secs.parse().map(MaxAge).map_err(Some),
("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some),
("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some),
("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some),
(left, right) => Ok(Extension(left.to_owned(), Some(right.to_owned())))
},
Some(idx) if idx + 1 < s.len() => {
match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) {
("max-age", secs) => secs.parse().map(MaxAge).map_err(Some),
("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some),
("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some),
("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some),
(left, right) => {
Ok(Extension(left.to_owned(), Some(right.to_owned())))
}
}
}
Some(_) => Err(None),
None => Ok(Extension(s.to_owned(), None))
}
None => Ok(Extension(s.to_owned(), None)),
},
}
}
}
@ -189,38 +193,56 @@ mod tests {
#[test]
fn test_parse_multiple_headers() {
let req = TestRequest::with_header(
header::CACHE_CONTROL, "no-cache, private").finish();
let req = TestRequest::with_header(header::CACHE_CONTROL, "no-cache, private")
.finish();
let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::NoCache,
CacheDirective::Private])))
assert_eq!(
cache.ok(),
Some(CacheControl(vec![
CacheDirective::NoCache,
CacheDirective::Private,
]))
)
}
#[test]
fn test_parse_argument() {
let req = TestRequest::with_header(
header::CACHE_CONTROL, "max-age=100, private").finish();
let req =
TestRequest::with_header(header::CACHE_CONTROL, "max-age=100, private")
.finish();
let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(100),
CacheDirective::Private])))
assert_eq!(
cache.ok(),
Some(CacheControl(vec![
CacheDirective::MaxAge(100),
CacheDirective::Private,
]))
)
}
#[test]
fn test_parse_quote_form() {
let req = TestRequest::with_header(
header::CACHE_CONTROL, "max-age=\"200\"").finish();
let req =
TestRequest::with_header(header::CACHE_CONTROL, "max-age=\"200\"").finish();
let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(200)])))
assert_eq!(
cache.ok(),
Some(CacheControl(vec![CacheDirective::MaxAge(200)]))
)
}
#[test]
fn test_parse_extension() {
let req = TestRequest::with_header(
header::CACHE_CONTROL, "foo, bar=baz").finish();
let req =
TestRequest::with_header(header::CACHE_CONTROL, "foo, bar=baz").finish();
let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![
CacheDirective::Extension("foo".to_owned(), None),
CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned()))])))
assert_eq!(
cache.ok(),
Some(CacheControl(vec![
CacheDirective::Extension("foo".to_owned(), None),
CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())),
]))
)
}
#[test]

View File

@ -1,21 +1,21 @@
use header::{QualityItem, CONTENT_LANGUAGE};
use language_tags::LanguageTag;
use header::{CONTENT_LANGUAGE, QualityItem};
header! {
/// `Content-Language` header, defined in
/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2)
///
///
/// The `Content-Language` header field describes the natural language(s)
/// of the intended audience for the representation. Note that this
/// might not be equivalent to all the languages used within the
/// representation.
///
///
/// # ABNF
///
/// ```text
/// Content-Language = 1#language-tag
/// ```
///
///
/// # Example values
///
/// * `da`
@ -28,7 +28,7 @@ header! {
/// # #[macro_use] extern crate language_tags;
/// use actix_web::HttpResponse;
/// # use actix_web::http::header::{ContentLanguage, qitem};
/// #
/// #
/// # fn main() {
/// let mut builder = HttpResponse::Ok();
/// builder.set(
@ -46,7 +46,7 @@ header! {
/// # use actix_web::http::header::{ContentLanguage, qitem};
/// #
/// # fn main() {
///
///
/// let mut builder = HttpResponse::Ok();
/// builder.set(
/// ContentLanguage(vec![

View File

@ -1,8 +1,8 @@
use error::ParseError;
use header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer,
CONTENT_RANGE};
use std::fmt::{self, Display, Write};
use std::str::FromStr;
use error::ParseError;
use header::{IntoHeaderValue, Writer,
HeaderValue, InvalidHeaderValueBytes, CONTENT_RANGE};
header! {
/// `Content-Range` header, defined in
@ -69,7 +69,6 @@ header! {
}
}
/// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2)
///
/// # ABNF
@ -99,7 +98,7 @@ pub enum ContentRangeSpec {
range: Option<(u64, u64)>,
/// Total length of the instance, can be omitted if unknown
instance_length: Option<u64>
instance_length: Option<u64>,
},
/// Custom range, with unit not registered at IANA
@ -108,15 +107,15 @@ pub enum ContentRangeSpec {
unit: String,
/// other-range-resp
resp: String
}
resp: String,
},
}
fn split_in_two(s: &str, separator: char) -> Option<(&str, &str)> {
let mut iter = s.splitn(2, separator);
match (iter.next(), iter.next()) {
(Some(a), Some(b)) => Some((a, b)),
_ => None
_ => None,
}
}
@ -126,40 +125,40 @@ impl FromStr for ContentRangeSpec {
fn from_str(s: &str) -> Result<Self, ParseError> {
let res = match split_in_two(s, ' ') {
Some(("bytes", resp)) => {
let (range, instance_length) = split_in_two(
resp, '/').ok_or(ParseError::Header)?;
let (range, instance_length) =
split_in_two(resp, '/').ok_or(ParseError::Header)?;
let instance_length = if instance_length == "*" {
None
} else {
Some(instance_length.parse()
.map_err(|_| ParseError::Header)?)
Some(instance_length
.parse()
.map_err(|_| ParseError::Header)?)
};
let range = if range == "*" {
None
} else {
let (first_byte, last_byte) = split_in_two(
range, '-').ok_or(ParseError::Header)?;
let first_byte = first_byte.parse()
.map_err(|_| ParseError::Header)?;
let last_byte = last_byte.parse()
.map_err(|_| ParseError::Header)?;
let (first_byte, last_byte) =
split_in_two(range, '-').ok_or(ParseError::Header)?;
let first_byte = first_byte.parse().map_err(|_| ParseError::Header)?;
let last_byte = last_byte.parse().map_err(|_| ParseError::Header)?;
if last_byte < first_byte {
return Err(ParseError::Header);
}
Some((first_byte, last_byte))
};
ContentRangeSpec::Bytes {range, instance_length}
}
Some((unit, resp)) => {
ContentRangeSpec::Unregistered {
unit: unit.to_owned(),
resp: resp.to_owned()
ContentRangeSpec::Bytes {
range,
instance_length,
}
}
_ => return Err(ParseError::Header)
Some((unit, resp)) => ContentRangeSpec::Unregistered {
unit: unit.to_owned(),
resp: resp.to_owned(),
},
_ => return Err(ParseError::Header),
};
Ok(res)
}
@ -168,12 +167,15 @@ impl FromStr for ContentRangeSpec {
impl Display for ContentRangeSpec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ContentRangeSpec::Bytes { range, instance_length } => {
ContentRangeSpec::Bytes {
range,
instance_length,
} => {
try!(f.write_str("bytes "));
match range {
Some((first_byte, last_byte)) => {
try!(write!(f, "{}-{}", first_byte, last_byte));
},
}
None => {
try!(f.write_str("*"));
}
@ -185,7 +187,10 @@ impl Display for ContentRangeSpec {
f.write_str("*")
}
}
ContentRangeSpec::Unregistered { ref unit, ref resp } => {
ContentRangeSpec::Unregistered {
ref unit,
ref resp,
} => {
try!(f.write_str(unit));
try!(f.write_str(" "));
f.write_str(resp)

View File

@ -1,6 +1,5 @@
use mime::{self, Mime};
use header::CONTENT_TYPE;
use mime::{self, Mime};
header! {
/// `Content-Type` header, defined in
@ -68,13 +67,15 @@ header! {
}
impl ContentType {
/// A constructor to easily create a `Content-Type: application/json` header.
/// A constructor to easily create a `Content-Type: application/json`
/// header.
#[inline]
pub fn json() -> ContentType {
ContentType(mime::APPLICATION_JSON)
}
/// A constructor to easily create a `Content-Type: text/plain; charset=utf-8` header.
/// A constructor to easily create a `Content-Type: text/plain;
/// charset=utf-8` header.
#[inline]
pub fn plaintext() -> ContentType {
ContentType(mime::TEXT_PLAIN_UTF_8)
@ -92,7 +93,8 @@ impl ContentType {
ContentType(mime::TEXT_XML)
}
/// A constructor to easily create a `Content-Type: application/www-form-url-encoded` header.
/// A constructor to easily create a `Content-Type:
/// application/www-form-url-encoded` header.
#[inline]
pub fn form_url_encoded() -> ContentType {
ContentType(mime::APPLICATION_WWW_FORM_URLENCODED)
@ -109,7 +111,8 @@ impl ContentType {
ContentType(mime::IMAGE_PNG)
}
/// A constructor to easily create a `Content-Type: application/octet-stream` header.
/// A constructor to easily create a `Content-Type:
/// application/octet-stream` header.
#[inline]
pub fn octet_stream() -> ContentType {
ContentType(mime::APPLICATION_OCTET_STREAM)

View File

@ -1,6 +1,5 @@
use header::{HttpDate, DATE};
use std::time::SystemTime;
use header::{DATE, HttpDate};
header! {
/// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2)

View File

@ -1,4 +1,4 @@
use header::{ETAG, EntityTag};
use header::{EntityTag, ETAG};
header! {
/// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3)

View File

@ -1,4 +1,4 @@
use header::{EXPIRES, HttpDate};
use header::{HttpDate, EXPIRES};
header! {
/// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3)

View File

@ -1,4 +1,4 @@
use header::{IF_MATCH, EntityTag};
use header::{EntityTag, IF_MATCH};
header! {
/// `If-Match` header, defined in

View File

@ -1,4 +1,4 @@
use header::{IF_MODIFIED_SINCE, HttpDate};
use header::{HttpDate, IF_MODIFIED_SINCE};
header! {
/// `If-Modified-Since` header, defined in

View File

@ -1,4 +1,4 @@
use header::{IF_NONE_MATCH, EntityTag};
use header::{EntityTag, IF_NONE_MATCH};
header! {
/// `If-None-Match` header, defined in
@ -66,8 +66,8 @@ header! {
#[cfg(test)]
mod tests {
use super::IfNoneMatch;
use header::{EntityTag, Header, IF_NONE_MATCH};
use test::TestRequest;
use header::{IF_NONE_MATCH, Header, EntityTag};
#[test]
fn test_if_none_match() {
@ -77,8 +77,9 @@ mod tests {
if_none_match = Header::parse(&req);
assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any));
let req = TestRequest::with_header(
IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]).finish();
let req =
TestRequest::with_header(IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..])
.finish();
if_none_match = Header::parse(&req);
let mut entities: Vec<EntityTag> = Vec::new();

View File

@ -1,10 +1,10 @@
use std::fmt::{self, Display, Write};
use error::ParseError;
use httpmessage::HttpMessage;
use http::header;
use header::from_one_raw_str;
use header::{IntoHeaderValue, Header, HeaderName, HeaderValue,
EntityTag, HttpDate, Writer, InvalidHeaderValueBytes};
use header::{EntityTag, Header, HeaderName, HeaderValue, HttpDate, IntoHeaderValue,
InvalidHeaderValueBytes, Writer};
use http::header;
use httpmessage::HttpMessage;
use std::fmt::{self, Display, Write};
/// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2)
///
@ -36,16 +36,19 @@ use header::{IntoHeaderValue, Header, HeaderName, HeaderValue,
///
/// ```rust
/// use actix_web::HttpResponse;
/// use actix_web::http::header::{IfRange, EntityTag};
/// use actix_web::http::header::{EntityTag, IfRange};
///
/// let mut builder = HttpResponse::Ok();
/// builder.set(IfRange::EntityTag(EntityTag::new(false, "xyzzy".to_owned())));
/// builder.set(IfRange::EntityTag(EntityTag::new(
/// false,
/// "xyzzy".to_owned(),
/// )));
/// ```
///
/// ```rust
/// use actix_web::HttpResponse;
/// use actix_web::http::header::IfRange;
/// use std::time::{SystemTime, Duration};
/// use std::time::{Duration, SystemTime};
///
/// let mut builder = HttpResponse::Ok();
/// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24);
@ -64,7 +67,9 @@ impl Header for IfRange {
header::IF_RANGE
}
#[inline]
fn parse<T>(msg: &T) -> Result<Self, ParseError> where T: HttpMessage
fn parse<T>(msg: &T) -> Result<Self, ParseError>
where
T: HttpMessage,
{
let etag: Result<EntityTag, _> =
from_one_raw_str(msg.headers().get(header::IF_RANGE));
@ -99,12 +104,11 @@ impl IntoHeaderValue for IfRange {
}
}
#[cfg(test)]
mod test_if_range {
use std::str;
use header::*;
use super::IfRange as HeaderField;
use header::*;
use std::str;
test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]);
test_header!(test2, vec![b"\"xyzzy\""]);
test_header!(test3, vec![b"this-is-invalid"], None::<IfRange>);

View File

@ -1,4 +1,4 @@
use header::{IF_UNMODIFIED_SINCE, HttpDate};
use header::{HttpDate, IF_UNMODIFIED_SINCE};
header! {
/// `If-Unmodified-Since` header, defined in

View File

@ -1,4 +1,4 @@
use header::{LAST_MODIFIED, HttpDate};
use header::{HttpDate, LAST_MODIFIED};
header! {
/// `Last-Modified` header, defined in

View File

@ -5,6 +5,7 @@
//! Several header fields use MIME values for their contents. Keeping with the
//! strongly-typed theme, the [mime](https://docs.rs/mime) crate
//! is used, such as `ContentType(pub Mime)`.
#![cfg_attr(rustfmt, rustfmt_skip)]
pub use self::accept_charset::AcceptCharset;
//pub use self::accept_encoding::AcceptEncoding;

View File

@ -1,8 +1,8 @@
use std::fmt::{self, Display};
use std::str::FromStr;
use header::parsing::from_one_raw_str;
use header::{Header, Raw};
use header::parsing::{from_one_raw_str};
/// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1)
///
@ -65,7 +65,7 @@ pub enum Range {
Bytes(Vec<ByteRangeSpec>),
/// Custom range, with unit not registered at IANA
/// (`other-range-unit`: String , `other-range-set`: String)
Unregistered(String, String)
Unregistered(String, String),
}
/// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`.
@ -77,25 +77,25 @@ pub enum ByteRangeSpec {
/// Get all bytes starting from x ("x-")
AllFrom(u64),
/// Get last x bytes ("-x")
Last(u64)
Last(u64),
}
impl ByteRangeSpec {
/// Given the full length of the entity, attempt to normalize the byte range
/// into an satisfiable end-inclusive (from, to) range.
///
/// The resulting range is guaranteed to be a satisfiable range within the bounds
/// of `0 <= from <= to < full_length`.
/// The resulting range is guaranteed to be a satisfiable range within the
/// bounds of `0 <= from <= to < full_length`.
///
/// If the byte range is deemed unsatisfiable, `None` is returned.
/// An unsatisfiable range is generally cause for a server to either reject
/// the client request with a `416 Range Not Satisfiable` status code, or to
/// simply ignore the range header and serve the full entity using a `200 OK`
/// status code.
/// simply ignore the range header and serve the full entity using a `200
/// OK` status code.
///
/// This function closely follows [RFC 7233][1] section 2.1.
/// As such, it considers ranges to be satisfiable if they meet the following
/// conditions:
/// As such, it considers ranges to be satisfiable if they meet the
/// following conditions:
///
/// > If a valid byte-range-set includes at least one byte-range-spec with
/// a first-byte-pos that is less than the current length of the
@ -125,14 +125,14 @@ impl ByteRangeSpec {
} else {
None
}
},
}
&ByteRangeSpec::AllFrom(from) => {
if from < full_length {
Some((from, full_length - 1))
} else {
None
}
},
}
&ByteRangeSpec::Last(last) => {
if last > 0 {
// From the RFC: If the selected representation is shorter
@ -160,11 +160,15 @@ impl Range {
/// Get byte range header with multiple subranges
/// ("bytes=from1-to1,from2-to2,fromX-toX")
pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range {
Range::Bytes(ranges.iter().map(|r| ByteRangeSpec::FromTo(r.0, r.1)).collect())
Range::Bytes(
ranges
.iter()
.map(|r| ByteRangeSpec::FromTo(r.0, r.1))
.collect(),
)
}
}
impl fmt::Display for ByteRangeSpec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -175,7 +179,6 @@ impl fmt::Display for ByteRangeSpec {
}
}
impl fmt::Display for Range {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -189,10 +192,10 @@ impl fmt::Display for Range {
try!(Display::fmt(range, f));
}
Ok(())
},
}
Range::Unregistered(ref unit, ref range_str) => {
write!(f, "{}={}", unit, range_str)
},
}
}
}
}
@ -211,11 +214,10 @@ impl FromStr for Range {
}
Ok(Range::Bytes(ranges))
}
(Some(unit), Some(range_str)) if unit != "" && range_str != "" => {
Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned()))
},
_ => Err(::Error::Header)
(Some(unit), Some(range_str)) if unit != "" && range_str != "" => Ok(
Range::Unregistered(unit.to_owned(), range_str.to_owned()),
),
_ => Err(::Error::Header),
}
}
}
@ -227,19 +229,20 @@ impl FromStr for ByteRangeSpec {
let mut parts = s.splitn(2, '-');
match (parts.next(), parts.next()) {
(Some(""), Some(end)) => {
end.parse().or(Err(::Error::Header)).map(ByteRangeSpec::Last)
},
(Some(start), Some("")) => {
start.parse().or(Err(::Error::Header)).map(ByteRangeSpec::AllFrom)
},
(Some(start), Some(end)) => {
match (start.parse(), end.parse()) {
(Ok(start), Ok(end)) if start <= end => Ok(ByteRangeSpec::FromTo(start, end)),
_ => Err(::Error::Header)
(Some(""), Some(end)) => end.parse()
.or(Err(::Error::Header))
.map(ByteRangeSpec::Last),
(Some(start), Some("")) => start
.parse()
.or(Err(::Error::Header))
.map(ByteRangeSpec::AllFrom),
(Some(start), Some(end)) => match (start.parse(), end.parse()) {
(Ok(start), Ok(end)) if start <= end => {
Ok(ByteRangeSpec::FromTo(start, end))
}
_ => Err(::Error::Header),
},
_ => Err(::Error::Header)
_ => Err(::Error::Header),
}
}
}
@ -248,14 +251,13 @@ fn from_comma_delimited<T: FromStr>(s: &str) -> Vec<T> {
s.split(',')
.filter_map(|x| match x.trim() {
"" => None,
y => Some(y)
y => Some(y),
})
.filter_map(|x| x.parse().ok())
.collect()
}
impl Header for Range {
fn header_name() -> &'static str {
static NAME: &'static str = "Range";
NAME
@ -268,51 +270,52 @@ impl Header for Range {
fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result {
f.fmt_line(self)
}
}
#[test]
fn test_parse_bytes_range_valid() {
let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap();
let r3 = Range::bytes(1, 100);
let r3 = Range::bytes(1, 100);
assert_eq!(r, r2);
assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap();
let r3 = Range::Bytes(
vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)]
);
let r2: Range =
Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap();
let r3 = Range::Bytes(vec![
ByteRangeSpec::FromTo(1, 100),
ByteRangeSpec::AllFrom(200),
]);
assert_eq!(r, r2);
assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap();
let r3 = Range::Bytes(
vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::Last(100)]
);
let r3 = Range::Bytes(vec![
ByteRangeSpec::FromTo(1, 100),
ByteRangeSpec::Last(100),
]);
assert_eq!(r, r2);
assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
assert_eq!(r, r2);
}
#[test]
fn test_parse_unregistered_range_valid() {
let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
assert_eq!(r, r2);
let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned());
let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned());
assert_eq!(r, r2);
let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned());
let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned());
assert_eq!(r, r2);
}
@ -346,10 +349,10 @@ fn test_fmt() {
let mut headers = Headers::new();
headers.set(
Range::Bytes(
vec![ByteRangeSpec::FromTo(0, 1000), ByteRangeSpec::AllFrom(2000)]
));
headers.set(Range::Bytes(vec![
ByteRangeSpec::FromTo(0, 1000),
ByteRangeSpec::AllFrom(2000),
]));
assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n");
headers.clear();
@ -358,30 +361,74 @@ fn test_fmt() {
assert_eq!(&headers.to_string(), "Range: bytes=\r\n");
headers.clear();
headers.set(Range::Unregistered("custom".to_owned(), "1-xxx".to_owned()));
headers.set(Range::Unregistered(
"custom".to_owned(),
"1-xxx".to_owned(),
));
assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n");
}
#[test]
fn test_byte_range_spec_to_satisfiable_range() {
assert_eq!(Some((0, 0)), ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3));
assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3));
assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0));
assert_eq!(
Some((0, 0)),
ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3)
);
assert_eq!(
Some((1, 2)),
ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3)
);
assert_eq!(
Some((1, 2)),
ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)
);
assert_eq!(Some((0, 2)), ByteRangeSpec::AllFrom(0).to_satisfiable_range(3));
assert_eq!(Some((2, 2)), ByteRangeSpec::AllFrom(2).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::AllFrom(3).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::AllFrom(5).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::AllFrom(0).to_satisfiable_range(0));
assert_eq!(
Some((0, 2)),
ByteRangeSpec::AllFrom(0).to_satisfiable_range(3)
);
assert_eq!(
Some((2, 2)),
ByteRangeSpec::AllFrom(2).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)
);
assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3));
assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3));
assert_eq!(Some((0, 2)), ByteRangeSpec::Last(5).to_satisfiable_range(3));
assert_eq!(
Some((1, 2)),
ByteRangeSpec::Last(2).to_satisfiable_range(3)
);
assert_eq!(
Some((2, 2)),
ByteRangeSpec::Last(1).to_satisfiable_range(3)
);
assert_eq!(
Some((0, 2)),
ByteRangeSpec::Last(5).to_satisfiable_range(3)
);
assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0));
}

View File

@ -5,9 +5,9 @@ use std::fmt;
use std::str::FromStr;
use bytes::{Bytes, BytesMut};
use modhttp::{Error as HttpError};
use modhttp::header::GetAll;
use mime::Mime;
use modhttp::Error as HttpError;
use modhttp::header::GetAll;
pub use modhttp::header::*;
@ -21,11 +21,12 @@ pub use self::common::*;
#[doc(hidden)]
pub use self::shared::*;
#[doc(hidden)]
/// A trait for any object that will represent a header field and value.
pub trait Header where Self: IntoHeaderValue {
pub trait Header
where
Self: IntoHeaderValue,
{
/// Returns the name of the header field
fn name() -> HeaderName;
@ -112,7 +113,7 @@ pub enum ContentEncoding {
/// Automatically select encoding based on encoding negotiation
Auto,
/// A format using the Brotli algorithm
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
Br,
/// A format using the zlib structure with deflate algorithm
Deflate,
@ -123,19 +124,18 @@ pub enum ContentEncoding {
}
impl ContentEncoding {
#[inline]
pub fn is_compression(&self) -> bool {
match *self {
ContentEncoding::Identity | ContentEncoding::Auto => false,
_ => true
_ => true,
}
}
#[inline]
pub fn as_str(&self) -> &'static str {
match *self {
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
ContentEncoding::Br => "br",
ContentEncoding::Gzip => "gzip",
ContentEncoding::Deflate => "deflate",
@ -147,7 +147,7 @@ impl ContentEncoding {
/// default quality value
pub fn quality(&self) -> f64 {
match *self {
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
ContentEncoding::Br => 1.1,
ContentEncoding::Gzip => 1.0,
ContentEncoding::Deflate => 0.9,
@ -160,7 +160,7 @@ impl ContentEncoding {
impl<'a> From<&'a str> for ContentEncoding {
fn from(s: &'a str) -> ContentEncoding {
match s.trim().to_lowercase().as_ref() {
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
"br" => ContentEncoding::Br,
"gzip" => ContentEncoding::Gzip,
"deflate" => ContentEncoding::Deflate,
@ -176,7 +176,9 @@ pub(crate) struct Writer {
impl Writer {
fn new() -> Writer {
Writer{buf: BytesMut::new()}
Writer {
buf: BytesMut::new(),
}
}
fn take(&mut self) -> Bytes {
self.buf.take().freeze()
@ -199,18 +201,20 @@ impl fmt::Write for Writer {
#[inline]
#[doc(hidden)]
/// Reads a comma-delimited raw header into a Vec.
pub fn from_comma_delimited<T: FromStr>(all: GetAll<HeaderValue>)
-> Result<Vec<T>, ParseError>
{
pub fn from_comma_delimited<T: FromStr>(
all: GetAll<HeaderValue>
) -> Result<Vec<T>, ParseError> {
let mut result = Vec::new();
for h in all {
let s = h.to_str().map_err(|_| ParseError::Header)?;
result.extend(s.split(',')
.filter_map(|x| match x.trim() {
"" => None,
y => Some(y)
})
.filter_map(|x| x.trim().parse().ok()))
result.extend(
s.split(',')
.filter_map(|x| match x.trim() {
"" => None,
y => Some(y),
})
.filter_map(|x| x.trim().parse().ok()),
)
}
Ok(result)
}
@ -218,13 +222,11 @@ pub fn from_comma_delimited<T: FromStr>(all: GetAll<HeaderValue>)
#[inline]
#[doc(hidden)]
/// Reads a single string when parsing a header.
pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>)
-> Result<T, ParseError>
{
pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>) -> Result<T, ParseError> {
if let Some(line) = val {
let line = line.to_str().map_err(|_| ParseError::Header)?;
if !line.is_empty() {
return T::from_str(line).or(Err(ParseError::Header))
return T::from_str(line).or(Err(ParseError::Header));
}
}
Err(ParseError::Header)
@ -234,7 +236,8 @@ pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>)
#[doc(hidden)]
/// Format an array into a comma-delimited string.
pub fn fmt_comma_delimited<T>(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result
where T: fmt::Display
where
T: fmt::Display,
{
let mut iter = parts.iter();
if let Some(part) = iter.next() {

View File

@ -1,7 +1,7 @@
#![allow(unused, deprecated)]
use std::ascii::AsciiExt;
use std::fmt::{self, Display};
use std::str::FromStr;
use std::ascii::AsciiExt;
use self::Charset::*;
@ -12,9 +12,9 @@ use self::Charset::*;
/// See [http://www.iana.org/assignments/character-sets/character-sets.xhtml][url].
///
/// [url]: http://www.iana.org/assignments/character-sets/character-sets.xhtml
#[derive(Clone,Debug,PartialEq)]
#[derive(Clone, Debug, PartialEq)]
#[allow(non_camel_case_types)]
pub enum Charset{
pub enum Charset {
/// US ASCII
Us_Ascii,
/// ISO-8859-1
@ -64,7 +64,7 @@ pub enum Charset{
/// KOI8-R
Koi8_R,
/// An arbitrary charset specified as a string
Ext(String)
Ext(String),
}
impl Charset {
@ -94,7 +94,7 @@ impl Charset {
Gb2312 => "GB2312",
Big5 => "5",
Koi8_R => "KOI8-R",
Ext(ref s) => s
Ext(ref s) => s,
}
}
}
@ -133,18 +133,18 @@ impl FromStr for Charset {
"GB2312" => Gb2312,
"5" => Big5,
"KOI8-R" => Koi8_R,
s => Ext(s.to_owned())
s => Ext(s.to_owned()),
})
}
}
#[test]
fn test_parse() {
assert_eq!(Us_Ascii,"us-ascii".parse().unwrap());
assert_eq!(Us_Ascii,"US-Ascii".parse().unwrap());
assert_eq!(Us_Ascii,"US-ASCII".parse().unwrap());
assert_eq!(Shift_Jis,"Shift-JIS".parse().unwrap());
assert_eq!(Ext("ABCD".to_owned()),"abcd".parse().unwrap());
assert_eq!(Us_Ascii, "us-ascii".parse().unwrap());
assert_eq!(Us_Ascii, "US-Ascii".parse().unwrap());
assert_eq!(Us_Ascii, "US-ASCII".parse().unwrap());
assert_eq!(Shift_Jis, "Shift-JIS".parse().unwrap());
assert_eq!(Ext("ABCD".to_owned()), "abcd".parse().unwrap());
}
#[test]

View File

@ -1,7 +1,8 @@
use std::fmt;
use std::str;
pub use self::Encoding::{Chunked, Brotli, Gzip, Deflate, Compress, Identity, EncodingExt, Trailers};
pub use self::Encoding::{Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip,
Identity, Trailers};
/// A value to represent an encoding used in `Transfer-Encoding`
/// or `Accept-Encoding` header.
@ -22,7 +23,7 @@ pub enum Encoding {
/// The `trailers` encoding.
Trailers,
/// Some other encoding that is less common, can be any String.
EncodingExt(String)
EncodingExt(String),
}
impl fmt::Display for Encoding {
@ -35,7 +36,7 @@ impl fmt::Display for Encoding {
Compress => "compress",
Identity => "identity",
Trailers => "trailers",
EncodingExt(ref s) => s.as_ref()
EncodingExt(ref s) => s.as_ref(),
})
}
}
@ -51,7 +52,7 @@ impl str::FromStr for Encoding {
"compress" => Ok(Compress),
"identity" => Ok(Identity),
"trailers" => Ok(Trailers),
_ => Ok(EncodingExt(s.to_owned()))
_ => Ok(EncodingExt(s.to_owned())),
}
}
}

View File

@ -1,21 +1,23 @@
use std::str::FromStr;
use header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer};
use std::fmt::{self, Display, Write};
use header::{HeaderValue, Writer, IntoHeaderValue, InvalidHeaderValueBytes};
use std::str::FromStr;
/// check that each char in the slice is either:
/// 1. `%x21`, or
/// 2. in the range `%x23` to `%x7E`, or
/// 3. above `%x80`
fn check_slice_validity(slice: &str) -> bool {
slice.bytes().all(|c|
c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80'))
slice
.bytes()
.all(|c| c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80'))
}
/// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3)
///
/// An entity tag consists of a string enclosed by two literal double quotes.
/// Preceding the first double quote is an optional weakness indicator,
/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and `W/"xyzzy"`.
/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and
/// `W/"xyzzy"`.
///
/// # ABNF
///
@ -28,9 +30,9 @@ fn check_slice_validity(slice: &str) -> bool {
/// ```
///
/// # Comparison
/// To check if two entity tags are equivalent in an application always use the `strong_eq` or
/// `weak_eq` methods based on the context of the Tag. Only use `==` to check if two tags are
/// identical.
/// To check if two entity tags are equivalent in an application always use the
/// `strong_eq` or `weak_eq` methods based on the context of the Tag. Only use
/// `==` to check if two tags are identical.
///
/// The example below shows the results for a set of entity-tag pairs and
/// both the weak and strong comparison function results:
@ -46,7 +48,7 @@ pub struct EntityTag {
/// Weakness indicator for the tag
pub weak: bool,
/// The opaque string in between the DQUOTEs
tag: String
tag: String,
}
impl EntityTag {
@ -85,8 +87,8 @@ impl EntityTag {
self.tag = tag
}
/// For strong comparison two entity-tags are equivalent if both are not weak and their
/// opaque-tags match character-by-character.
/// For strong comparison two entity-tags are equivalent if both are not
/// weak and their opaque-tags match character-by-character.
pub fn strong_eq(&self, other: &EntityTag) -> bool {
!self.weak && !other.weak && self.tag == other.tag
}
@ -131,13 +133,21 @@ impl FromStr for EntityTag {
}
// The etag is weak if its first char is not a DQUOTE.
if slice.len() >= 2 && slice.starts_with('"')
&& check_slice_validity(&slice[1..length-1]) {
&& check_slice_validity(&slice[1..length - 1])
{
// No need to check if the last char is a DQUOTE,
// we already did that above.
return Ok(EntityTag { weak: false, tag: slice[1..length-1].to_owned() });
return Ok(EntityTag {
weak: false,
tag: slice[1..length - 1].to_owned(),
});
} else if slice.len() >= 4 && slice.starts_with("W/\"")
&& check_slice_validity(&slice[3..length-1]) {
return Ok(EntityTag { weak: true, tag: slice[3..length-1].to_owned() });
&& check_slice_validity(&slice[3..length - 1])
{
return Ok(EntityTag {
weak: true,
tag: slice[3..length - 1].to_owned(),
});
}
Err(::error::ParseError::Header)
}
@ -149,7 +159,7 @@ impl IntoHeaderValue for EntityTag {
fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = Writer::new();
write!(wrt, "{}", self).unwrap();
unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.take()))}
unsafe { Ok(HeaderValue::from_shared_unchecked(wrt.take())) }
}
}
@ -160,22 +170,37 @@ mod tests {
#[test]
fn test_etag_parse_success() {
// Expected success
assert_eq!("\"foobar\"".parse::<EntityTag>().unwrap(),
EntityTag::strong("foobar".to_owned()));
assert_eq!("\"\"".parse::<EntityTag>().unwrap(),
EntityTag::strong("".to_owned()));
assert_eq!("W/\"weaktag\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("weaktag".to_owned()));
assert_eq!("W/\"\x65\x62\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("\x65\x62".to_owned()));
assert_eq!("W/\"\"".parse::<EntityTag>().unwrap(), EntityTag::weak("".to_owned()));
assert_eq!(
"\"foobar\"".parse::<EntityTag>().unwrap(),
EntityTag::strong("foobar".to_owned())
);
assert_eq!(
"\"\"".parse::<EntityTag>().unwrap(),
EntityTag::strong("".to_owned())
);
assert_eq!(
"W/\"weaktag\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("weaktag".to_owned())
);
assert_eq!(
"W/\"\x65\x62\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("\x65\x62".to_owned())
);
assert_eq!(
"W/\"\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("".to_owned())
);
}
#[test]
fn test_etag_parse_failures() {
// Expected failures
assert!("no-dquotes".parse::<EntityTag>().is_err());
assert!("w/\"the-first-w-is-case-sensitive\"".parse::<EntityTag>().is_err());
assert!(
"w/\"the-first-w-is-case-sensitive\""
.parse::<EntityTag>()
.is_err()
);
assert!("".parse::<EntityTag>().is_err());
assert!("\"unmatched-dquotes1".parse::<EntityTag>().is_err());
assert!("unmatched-dquotes2\"".parse::<EntityTag>().is_err());
@ -184,11 +209,26 @@ mod tests {
#[test]
fn test_etag_fmt() {
assert_eq!(format!("{}", EntityTag::strong("foobar".to_owned())), "\"foobar\"");
assert_eq!(format!("{}", EntityTag::strong("".to_owned())), "\"\"");
assert_eq!(format!("{}", EntityTag::weak("weak-etag".to_owned())), "W/\"weak-etag\"");
assert_eq!(format!("{}", EntityTag::weak("\u{0065}".to_owned())), "W/\"\x65\"");
assert_eq!(format!("{}", EntityTag::weak("".to_owned())), "W/\"\"");
assert_eq!(
format!("{}", EntityTag::strong("foobar".to_owned())),
"\"foobar\""
);
assert_eq!(
format!("{}", EntityTag::strong("".to_owned())),
"\"\""
);
assert_eq!(
format!("{}", EntityTag::weak("weak-etag".to_owned())),
"W/\"weak-etag\""
);
assert_eq!(
format!("{}", EntityTag::weak("\u{0065}".to_owned())),
"W/\"\x65\""
);
assert_eq!(
format!("{}", EntityTag::weak("".to_owned())),
"W/\"\""
);
}
#[test]

View File

@ -3,14 +3,13 @@ use std::io::Write;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use time;
use bytes::{BytesMut, BufMut};
use bytes::{BufMut, BytesMut};
use http::header::{HeaderValue, InvalidHeaderValueBytes};
use time;
use error::ParseError;
use header::IntoHeaderValue;
/// A timestamp with HTTP formatting and parsing
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HttpDate(time::Tm);
@ -19,11 +18,10 @@ impl FromStr for HttpDate {
type Err = ParseError;
fn from_str(s: &str) -> Result<HttpDate, ParseError> {
match time::strptime(s, "%a, %d %b %Y %T %Z").or_else(|_| {
time::strptime(s, "%A, %d-%b-%y %T %Z")
}).or_else(|_| {
time::strptime(s, "%c")
}) {
match time::strptime(s, "%a, %d %b %Y %T %Z")
.or_else(|_| time::strptime(s, "%A, %d-%b-%y %T %Z"))
.or_else(|_| time::strptime(s, "%c"))
{
Ok(t) => Ok(HttpDate(t)),
Err(_) => Err(ParseError::Header),
}
@ -47,11 +45,14 @@ impl From<SystemTime> for HttpDate {
let tmspec = match sys.duration_since(UNIX_EPOCH) {
Ok(dur) => {
time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32)
},
}
Err(err) => {
let neg = err.duration();
time::Timespec::new(-(neg.as_secs() as i64), -(neg.subsec_nanos() as i32))
},
time::Timespec::new(
-(neg.as_secs() as i64),
-(neg.subsec_nanos() as i32),
)
}
};
HttpDate(time::at_utc(tmspec))
}
@ -63,7 +64,11 @@ impl IntoHeaderValue for HttpDate {
fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer();
write!(wrt, "{}", self.0.rfc822()).unwrap();
unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.get_mut().take().freeze()))}
unsafe {
Ok(HeaderValue::from_shared_unchecked(
wrt.get_mut().take().freeze(),
))
}
}
}
@ -80,18 +85,43 @@ impl From<HttpDate> for SystemTime {
#[cfg(test)]
mod tests {
use time::Tm;
use super::HttpDate;
use time::Tm;
const NOV_07: HttpDate = HttpDate(Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8, tm_mday: 7, tm_mon: 10, tm_year: 94,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0});
tm_nsec: 0,
tm_sec: 37,
tm_min: 48,
tm_hour: 8,
tm_mday: 7,
tm_mon: 10,
tm_year: 94,
tm_wday: 0,
tm_isdst: 0,
tm_yday: 0,
tm_utcoff: 0,
});
#[test]
fn test_date() {
assert_eq!("Sun, 07 Nov 1994 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07);
assert_eq!("Sunday, 07-Nov-94 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07);
assert_eq!("Sun Nov 7 08:48:37 1994".parse::<HttpDate>().unwrap(), NOV_07);
assert_eq!(
"Sun, 07 Nov 1994 08:48:37 GMT"
.parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert_eq!(
"Sunday, 07-Nov-94 08:48:37 GMT"
.parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert_eq!(
"Sun Nov 7 08:48:37 1994"
.parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert!("this-is-no-date".parse::<HttpDate>().is_err());
}
}

View File

@ -4,11 +4,11 @@ pub use self::charset::Charset;
pub use self::encoding::Encoding;
pub use self::entity::EntityTag;
pub use self::httpdate::HttpDate;
pub use self::quality_item::{q, qitem, Quality, QualityItem};
pub use language_tags::LanguageTag;
pub use self::quality_item::{Quality, QualityItem, qitem, q};
mod charset;
mod entity;
mod encoding;
mod entity;
mod httpdate;
mod quality_item;

View File

@ -13,11 +13,13 @@ use self::internal::IntoQuality;
///
/// # Implementation notes
///
/// The quality value is defined as a number between 0 and 1 with three decimal places. This means
/// there are 1001 possible values. Since floating point numbers are not exact and the smallest
/// floating point data type (`f32`) consumes four bytes, hyper uses an `u16` value to store the
/// quality internally. For performance reasons you may set quality directly to a value between
/// 0 and 1000 e.g. `Quality(532)` matches the quality `q=0.532`.
/// The quality value is defined as a number between 0 and 1 with three decimal
/// places. This means there are 1001 possible values. Since floating point
/// numbers are not exact and the smallest floating point data type (`f32`)
/// consumes four bytes, hyper uses an `u16` value to store the
/// quality internally. For performance reasons you may set quality directly to
/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality
/// `q=0.532`.
///
/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1)
/// gives more information on quality values in HTTP header fields.
@ -61,7 +63,11 @@ impl<T: fmt::Display> fmt::Display for QualityItem<T> {
match self.quality.0 {
1000 => Ok(()),
0 => f.write_str("; q=0"),
x => write!(f, "; q=0.{}", format!("{:03}", x).trim_right_matches('0'))
x => write!(
f,
"; q=0.{}",
format!("{:03}", x).trim_right_matches('0')
),
}
}
}
@ -96,7 +102,7 @@ impl<T: str::FromStr> str::FromStr for QualityItem<T> {
} else {
return Err(::error::ParseError::Header);
}
},
}
Err(_) => return Err(::error::ParseError::Header),
}
}
@ -114,7 +120,10 @@ fn from_f32(f: f32) -> Quality {
// this function is only used internally. A check that `f` is within range
// should be done before calling this method. Just in case, this
// debug_assert should catch if we were forgetful
debug_assert!(f >= 0f32 && f <= 1f32, "q value must be between 0.0 and 1.0");
debug_assert!(
f >= 0f32 && f <= 1f32,
"q value must be between 0.0 and 1.0"
);
Quality((f * 1000f32) as u16)
}
@ -125,7 +134,7 @@ pub fn qitem<T>(item: T) -> QualityItem<T> {
}
/// Convenience function to create a `Quality` from a float or integer.
///
///
/// Implemented for `u16` and `f32`. Panics if value is out of range.
pub fn q<T: IntoQuality>(val: T) -> Quality {
val.into_quality()
@ -147,7 +156,10 @@ mod internal {
impl IntoQuality for f32 {
fn into_quality(self) -> Quality {
assert!(self >= 0f32 && self <= 1f32, "float must be between 0.0 and 1.0");
assert!(
self >= 0f32 && self <= 1f32,
"float must be between 0.0 and 1.0"
);
super::from_f32(self)
}
}
@ -159,7 +171,6 @@ mod internal {
}
}
pub trait Sealed {}
impl Sealed for u16 {}
impl Sealed for f32 {}
@ -167,8 +178,8 @@ mod internal {
#[cfg(test)]
mod tests {
use super::*;
use super::super::encoding::*;
use super::*;
#[test]
fn test_quality_item_fmt_q_1() {
@ -183,7 +194,7 @@ mod tests {
#[test]
fn test_quality_item_fmt_q_05() {
// Custom value
let x = QualityItem{
let x = QualityItem {
item: EncodingExt("identity".to_owned()),
quality: Quality(500),
};
@ -193,7 +204,7 @@ mod tests {
#[test]
fn test_quality_item_fmt_q_0() {
// Custom value
let x = QualityItem{
let x = QualityItem {
item: EncodingExt("identity".to_owned()),
quality: Quality(0),
};
@ -203,22 +214,46 @@ mod tests {
#[test]
fn test_quality_item_from_str1() {
let x: Result<QualityItem<Encoding>, _> = "chunked".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), });
assert_eq!(
x.unwrap(),
QualityItem {
item: Chunked,
quality: Quality(1000),
}
);
}
#[test]
fn test_quality_item_from_str2() {
let x: Result<QualityItem<Encoding>, _> = "chunked; q=1".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), });
assert_eq!(
x.unwrap(),
QualityItem {
item: Chunked,
quality: Quality(1000),
}
);
}
#[test]
fn test_quality_item_from_str3() {
let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.5".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(500), });
assert_eq!(
x.unwrap(),
QualityItem {
item: Gzip,
quality: Quality(500),
}
);
}
#[test]
fn test_quality_item_from_str4() {
let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.273".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(273), });
assert_eq!(
x.unwrap(),
QualityItem {
item: Gzip,
quality: Quality(273),
}
);
}
#[test]
fn test_quality_item_from_str5() {
@ -245,14 +280,14 @@ mod tests {
#[test]
#[should_panic] // FIXME - 32-bit msvc unwinding broken
#[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)]
#[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)]
fn test_quality_invalid() {
q(-1.0);
}
#[test]
#[should_panic] // FIXME - 32-bit msvc unwinding broken
#[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)]
#[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)]
fn test_quality_invalid2() {
q(2.0);
}
@ -260,6 +295,10 @@ mod tests {
#[test]
fn test_fuzzing_bugs() {
assert!("99999;".parse::<QualityItem<String>>().is_err());
assert!("\x0d;;;=\u{d6aa}==".parse::<QualityItem<String>>().is_err())
assert!(
"\x0d;;;=\u{d6aa}=="
.parse::<QualityItem<String>>()
.is_err()
)
}
}

View File

@ -1,7 +1,7 @@
//! Various helpers
use regex::Regex;
use http::{header, StatusCode};
use regex::Regex;
use handler::Handler;
use httprequest::HttpRequest;
@ -24,9 +24,11 @@ use httpresponse::HttpResponse;
/// defined with trailing slash and the request comes without it, it will
/// append it automatically.
///
/// If *merge* is *true*, merge multiple consecutive slashes in the path into one.
/// If *merge* is *true*, merge multiple consecutive slashes in the path into
/// one.
///
/// This handler designed to be use as a handler for application's *default resource*.
/// This handler designed to be use as a handler for application's *default
/// resource*.
///
/// ```rust
/// # extern crate actix_web;
@ -55,7 +57,8 @@ pub struct NormalizePath {
impl Default for NormalizePath {
/// Create default `NormalizePath` instance, *append* is set to *true*,
/// *merge* is set to *true* and *redirect* is set to `StatusCode::MOVED_PERMANENTLY`
/// *merge* is set to *true* and *redirect* is set to
/// `StatusCode::MOVED_PERMANENTLY`
fn default() -> NormalizePath {
NormalizePath {
append: true,
@ -91,7 +94,11 @@ impl<S> Handler<S> for NormalizePath {
let p = self.re_merge.replace_all(req.path(), "/");
if p.len() != req.path().len() {
if router.has_route(p.as_ref()) {
let p = if !query.is_empty() { p + "?" + query } else { p };
let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_ref())
.finish();
@ -100,10 +107,14 @@ impl<S> Handler<S> for NormalizePath {
if self.append && !p.ends_with('/') {
let p = p.as_ref().to_owned() + "/";
if router.has_route(&p) {
let p = if !query.is_empty() { p + "?" + query } else { p };
let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_str())
.finish()
.finish();
}
}
@ -113,11 +124,13 @@ impl<S> Handler<S> for NormalizePath {
if router.has_route(p) {
let mut req = HttpResponse::build(self.redirect);
return if !query.is_empty() {
req.header(header::LOCATION, (p.to_owned() + "?" + query).as_str())
req.header(
header::LOCATION,
(p.to_owned() + "?" + query).as_str(),
)
} else {
req.header(header::LOCATION, p)
}
.finish();
}.finish();
}
}
} else if p.ends_with('/') {
@ -126,12 +139,13 @@ impl<S> Handler<S> for NormalizePath {
if router.has_route(p) {
let mut req = HttpResponse::build(self.redirect);
return if !query.is_empty() {
req.header(header::LOCATION,
(p.to_owned() + "?" + query).as_str())
req.header(
header::LOCATION,
(p.to_owned() + "?" + query).as_str(),
)
} else {
req.header(header::LOCATION, p)
}
.finish();
}.finish();
}
}
}
@ -139,7 +153,11 @@ impl<S> Handler<S> for NormalizePath {
if self.append && !req.path().ends_with('/') {
let p = req.path().to_owned() + "/";
if router.has_route(&p) {
let p = if !query.is_empty() { p + "?" + query } else { p };
let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_str())
.finish();
@ -153,9 +171,9 @@ impl<S> Handler<S> for NormalizePath {
#[cfg(test)]
mod tests {
use super::*;
use application::App;
use http::{header, Method};
use test::TestRequest;
use application::App;
fn index(_req: HttpRequest) -> HttpResponse {
HttpResponse::new(StatusCode::OK)
@ -170,17 +188,32 @@ mod tests {
.finish();
// trailing slashes
let params =
vec![("/resource1", "", StatusCode::OK),
("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY),
("/resource2", "/resource2/", StatusCode::MOVED_PERMANENTLY),
("/resource2/", "", StatusCode::OK),
("/resource1?p1=1&p2=2", "", StatusCode::OK),
("/resource1/?p1=1&p2=2", "/resource1?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY),
("/resource2?p1=1&p2=2", "/resource2/?p1=1&p2=2",
StatusCode::MOVED_PERMANENTLY),
("/resource2/?p1=1&p2=2", "", StatusCode::OK)
];
let params = vec![
("/resource1", "", StatusCode::OK),
(
"/resource1/",
"/resource1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2",
"/resource2/",
StatusCode::MOVED_PERMANENTLY,
),
("/resource2/", "", StatusCode::OK),
("/resource1?p1=1&p2=2", "", StatusCode::OK),
(
"/resource1/?p1=1&p2=2",
"/resource1?p1=1&p2=2",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2?p1=1&p2=2",
"/resource2/?p1=1&p2=2",
StatusCode::MOVED_PERMANENTLY,
),
("/resource2/?p1=1&p2=2", "", StatusCode::OK),
];
for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish());
let resp = app.run(req);
@ -189,7 +222,12 @@ mod tests {
if !target.is_empty() {
assert_eq!(
target,
r.headers().get(header::LOCATION).unwrap().to_str().unwrap());
r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
}
}
}
@ -199,19 +237,25 @@ mod tests {
let mut app = App::new()
.resource("/resource1", |r| r.method(Method::GET).f(index))
.resource("/resource2/", |r| r.method(Method::GET).f(index))
.default_resource(|r| r.h(
NormalizePath::new(false, true, StatusCode::MOVED_PERMANENTLY)))
.default_resource(|r| {
r.h(NormalizePath::new(
false,
true,
StatusCode::MOVED_PERMANENTLY,
))
})
.finish();
// trailing slashes
let params = vec![("/resource1", StatusCode::OK),
("/resource1/", StatusCode::MOVED_PERMANENTLY),
("/resource2", StatusCode::NOT_FOUND),
("/resource2/", StatusCode::OK),
("/resource1?p1=1&p2=2", StatusCode::OK),
("/resource1/?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY),
("/resource2?p1=1&p2=2", StatusCode::NOT_FOUND),
("/resource2/?p1=1&p2=2", StatusCode::OK)
let params = vec![
("/resource1", StatusCode::OK),
("/resource1/", StatusCode::MOVED_PERMANENTLY),
("/resource2", StatusCode::NOT_FOUND),
("/resource2/", StatusCode::OK),
("/resource1?p1=1&p2=2", StatusCode::OK),
("/resource1/?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY),
("/resource2?p1=1&p2=2", StatusCode::NOT_FOUND),
("/resource2/?p1=1&p2=2", StatusCode::OK),
];
for (path, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -232,21 +276,77 @@ mod tests {
// trailing slashes
let params = vec![
("/resource1/a/b", "", StatusCode::OK),
("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY),
("/resource1//", "/resource1", StatusCode::MOVED_PERMANENTLY),
("//resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("//resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("//resource1//a//b//", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
(
"/resource1/",
"/resource1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource1//",
"/resource1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b//",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
("/resource1/a/b?p=1", "", StatusCode::OK),
("//resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("//resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a//b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
(
"//resource1//a//b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b//?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
];
for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -256,7 +356,12 @@ mod tests {
if !target.is_empty() {
assert_eq!(
target,
r.headers().get(header::LOCATION).unwrap().to_str().unwrap());
r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
}
}
}
@ -274,38 +379,158 @@ mod tests {
// trailing slashes
let params = vec![
("/resource1/a/b", "", StatusCode::OK),
("/resource1/a/b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b//", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY),
("/resource2/a/b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
(
"/resource1/a/b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b//",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2/a/b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
("/resource2/a/b/", "", StatusCode::OK),
("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("///resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("///resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("/////resource2/a///b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
("/////resource2/a///b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY),
(
"//resource2//a//b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
("/resource1/a/b?p=1", "", StatusCode::OK),
("/resource1/a/b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("///resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource1/a///b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY),
("/resource2/a/b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("///resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("///resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource2/a///b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
("/////resource2/a///b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY),
(
"/resource1/a/b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b//?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2/a/b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
];
for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -314,7 +539,13 @@ mod tests {
assert_eq!(r.status(), code);
if !target.is_empty() {
assert_eq!(
target, r.headers().get(header::LOCATION).unwrap().to_str().unwrap());
target,
r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
}
}
}

View File

@ -4,127 +4,155 @@ use http::StatusCode;
use body::Body;
use error::Error;
use handler::{Reply, Handler, RouteHandler, Responder};
use handler::{Handler, Reply, Responder, RouteHandler};
use httprequest::HttpRequest;
use httpresponse::{HttpResponse, HttpResponseBuilder};
#[deprecated(since="0.5.0", note="please use `HttpResponse::Ok()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Ok()` instead")]
pub const HttpOk: StaticResponse = StaticResponse(StatusCode::OK);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Created()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Created()` instead")]
pub const HttpCreated: StaticResponse = StaticResponse(StatusCode::CREATED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Accepted()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Accepted()` instead")]
pub const HttpAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::pNonAuthoritativeInformation()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::pNonAuthoritativeInformation()` instead")]
pub const HttpNonAuthoritativeInformation: StaticResponse =
StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NoContent()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::NoContent()` instead")]
pub const HttpNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ResetContent()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ResetContent()` instead")]
pub const HttpResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PartialContent()` instead")]
pub const HttpPartialContent: StaticResponse = StaticResponse(StatusCode::PARTIAL_CONTENT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MultiStatus()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PartialContent()` instead")]
pub const HttpPartialContent: StaticResponse =
StaticResponse(StatusCode::PARTIAL_CONTENT);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::MultiStatus()` instead")]
pub const HttpMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS);
#[deprecated(since="0.5.0", note="please use `HttpResponse::AlreadyReported()` instead")]
pub const HttpAlreadyReported: StaticResponse = StaticResponse(StatusCode::ALREADY_REPORTED);
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::AlreadyReported()` instead")]
pub const HttpAlreadyReported: StaticResponse =
StaticResponse(StatusCode::ALREADY_REPORTED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MultipleChoices()` instead")]
pub const HttpMultipleChoices: StaticResponse = StaticResponse(StatusCode::MULTIPLE_CHOICES);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MovedPermanently()` instead")]
pub const HttpMovedPermanently: StaticResponse = StaticResponse(StatusCode::MOVED_PERMANENTLY);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Found()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::MultipleChoices()` instead")]
pub const HttpMultipleChoices: StaticResponse =
StaticResponse(StatusCode::MULTIPLE_CHOICES);
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::MovedPermanently()` instead")]
pub const HttpMovedPermanently: StaticResponse =
StaticResponse(StatusCode::MOVED_PERMANENTLY);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Found()` instead")]
pub const HttpFound: StaticResponse = StaticResponse(StatusCode::FOUND);
#[deprecated(since="0.5.0", note="please use `HttpResponse::SeeOther()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::SeeOther()` instead")]
pub const HttpSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotModified()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::NotModified()` instead")]
pub const HttpNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::UseProxy()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::UseProxy()` instead")]
pub const HttpUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY);
#[deprecated(since="0.5.0", note="please use `HttpResponse::TemporaryRedirect()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::TemporaryRedirect()` instead")]
pub const HttpTemporaryRedirect: StaticResponse =
StaticResponse(StatusCode::TEMPORARY_REDIRECT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PermanentRedirect()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PermanentRedirect()` instead")]
pub const HttpPermanentRedirect: StaticResponse =
StaticResponse(StatusCode::PERMANENT_REDIRECT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::BadRequest()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::BadRequest()` instead")]
pub const HttpBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Unauthorized()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::Unauthorized()` instead")]
pub const HttpUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PaymentRequired()` instead")]
pub const HttpPaymentRequired: StaticResponse = StaticResponse(StatusCode::PAYMENT_REQUIRED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Forbidden()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PaymentRequired()` instead")]
pub const HttpPaymentRequired: StaticResponse =
StaticResponse(StatusCode::PAYMENT_REQUIRED);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Forbidden()` instead")]
pub const HttpForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotFound()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::NotFound()` instead")]
pub const HttpNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MethodNotAllowed()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::MethodNotAllowed()` instead")]
pub const HttpMethodNotAllowed: StaticResponse =
StaticResponse(StatusCode::METHOD_NOT_ALLOWED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotAcceptable()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::NotAcceptable()` instead")]
pub const HttpNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::ProxyAuthenticationRequired()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ProxyAuthenticationRequired()` instead")]
pub const HttpProxyAuthenticationRequired: StaticResponse =
StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::RequestTimeout()` instead")]
pub const HttpRequestTimeout: StaticResponse = StaticResponse(StatusCode::REQUEST_TIMEOUT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Conflict()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::RequestTimeout()` instead")]
pub const HttpRequestTimeout: StaticResponse =
StaticResponse(StatusCode::REQUEST_TIMEOUT);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Conflict()` instead")]
pub const HttpConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Gone()` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Gone()` instead")]
pub const HttpGone: StaticResponse = StaticResponse(StatusCode::GONE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::LengthRequired()` instead")]
pub const HttpLengthRequired: StaticResponse = StaticResponse(StatusCode::LENGTH_REQUIRED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PreconditionFailed()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::LengthRequired()` instead")]
pub const HttpLengthRequired: StaticResponse =
StaticResponse(StatusCode::LENGTH_REQUIRED);
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PreconditionFailed()` instead")]
pub const HttpPreconditionFailed: StaticResponse =
StaticResponse(StatusCode::PRECONDITION_FAILED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PayloadTooLarge()` instead")]
pub const HttpPayloadTooLarge: StaticResponse = StaticResponse(StatusCode::PAYLOAD_TOO_LARGE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::UriTooLong()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PayloadTooLarge()` instead")]
pub const HttpPayloadTooLarge: StaticResponse =
StaticResponse(StatusCode::PAYLOAD_TOO_LARGE);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::UriTooLong()` instead")]
pub const HttpUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::UnsupportedMediaType()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::UnsupportedMediaType()` instead")]
pub const HttpUnsupportedMediaType: StaticResponse =
StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::RangeNotSatisfiable()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::RangeNotSatisfiable()` instead")]
pub const HttpRangeNotSatisfiable: StaticResponse =
StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ExpectationFailed()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ExpectationFailed()` instead")]
pub const HttpExpectationFailed: StaticResponse =
StaticResponse(StatusCode::EXPECTATION_FAILED);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::InternalServerError()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::InternalServerError()` instead")]
pub const HttpInternalServerError: StaticResponse =
StaticResponse(StatusCode::INTERNAL_SERVER_ERROR);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotImplemented()` instead")]
pub const HttpNotImplemented: StaticResponse = StaticResponse(StatusCode::NOT_IMPLEMENTED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::BadGateway()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::NotImplemented()` instead")]
pub const HttpNotImplemented: StaticResponse =
StaticResponse(StatusCode::NOT_IMPLEMENTED);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::BadGateway()` instead")]
pub const HttpBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ServiceUnavailable()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ServiceUnavailable()` instead")]
pub const HttpServiceUnavailable: StaticResponse =
StaticResponse(StatusCode::SERVICE_UNAVAILABLE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::GatewayTimeout()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::GatewayTimeout()` instead")]
pub const HttpGatewayTimeout: StaticResponse =
StaticResponse(StatusCode::GATEWAY_TIMEOUT);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::VersionNotSupported()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::VersionNotSupported()` instead")]
pub const HttpVersionNotSupported: StaticResponse =
StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::VariantAlsoNegotiates()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::VariantAlsoNegotiates()` instead")]
pub const HttpVariantAlsoNegotiates: StaticResponse =
StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES);
#[deprecated(since="0.5.0",
note="please use `HttpResponse::InsufficientStorage()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::InsufficientStorage()` instead")]
pub const HttpInsufficientStorage: StaticResponse =
StaticResponse(StatusCode::INSUFFICIENT_STORAGE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::LoopDetected()` instead")]
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::LoopDetected()` instead")]
pub const HttpLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED);
#[deprecated(since="0.5.0", note="please use `HttpResponse` instead")]
#[deprecated(since = "0.5.0", note = "please use `HttpResponse` instead")]
#[derive(Copy, Clone, Debug)]
pub struct StaticResponse(StatusCode);
@ -186,14 +214,17 @@ macro_rules! STATIC_RESP {
pub fn $name() -> HttpResponseBuilder {
HttpResponse::build($status)
}
}
};
}
impl HttpResponse {
STATIC_RESP!(Ok, StatusCode::OK);
STATIC_RESP!(Created, StatusCode::CREATED);
STATIC_RESP!(Accepted, StatusCode::ACCEPTED);
STATIC_RESP!(NonAuthoritativeInformation, StatusCode::NON_AUTHORITATIVE_INFORMATION);
STATIC_RESP!(
NonAuthoritativeInformation,
StatusCode::NON_AUTHORITATIVE_INFORMATION
);
STATIC_RESP!(NoContent, StatusCode::NO_CONTENT);
STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT);
@ -218,7 +249,10 @@ impl HttpResponse {
STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN);
STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED);
STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE);
STATIC_RESP!(ProxyAuthenticationRequired, StatusCode::PROXY_AUTHENTICATION_REQUIRED);
STATIC_RESP!(
ProxyAuthenticationRequired,
StatusCode::PROXY_AUTHENTICATION_REQUIRED
);
STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT);
STATIC_RESP!(Conflict, StatusCode::CONFLICT);
STATIC_RESP!(Gone, StatusCode::GONE);
@ -226,7 +260,10 @@ impl HttpResponse {
STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED);
STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE);
STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG);
STATIC_RESP!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE);
STATIC_RESP!(
UnsupportedMediaType,
StatusCode::UNSUPPORTED_MEDIA_TYPE
);
STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE);
STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED);
@ -235,16 +272,22 @@ impl HttpResponse {
STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY);
STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE);
STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT);
STATIC_RESP!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED);
STATIC_RESP!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES);
STATIC_RESP!(
VersionNotSupported,
StatusCode::HTTP_VERSION_NOT_SUPPORTED
);
STATIC_RESP!(
VariantAlsoNegotiates,
StatusCode::VARIANT_ALSO_NEGOTIATES
);
STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE);
STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED);
}
#[cfg(test)]
mod tests {
use super::{Body, HttpBadRequest, HttpOk, HttpResponse};
use http::StatusCode;
use super::{HttpOk, HttpBadRequest, Body, HttpResponse};
#[test]
fn test_build() {

View File

@ -1,45 +1,45 @@
use std::str;
use bytes::{Bytes, BytesMut};
use futures::{Future, Stream, Poll};
use http_range::HttpRange;
use serde::de::DeserializeOwned;
use mime::Mime;
use serde_urlencoded;
use encoding::all::UTF_8;
use encoding::EncodingRef;
use encoding::types::{Encoding, DecoderTrap};
use encoding::all::UTF_8;
use encoding::label::encoding_from_whatwg_label;
use encoding::types::{DecoderTrap, Encoding};
use futures::{Future, Poll, Stream};
use http::{header, HeaderMap};
use http_range::HttpRange;
use mime::Mime;
use serde::de::DeserializeOwned;
use serde_urlencoded;
use std::str;
use json::JsonBody;
use error::{ContentTypeError, HttpRangeError, ParseError, PayloadError, UrlencodedError};
use header::Header;
use json::JsonBody;
use multipart::Multipart;
use error::{ParseError, ContentTypeError,
HttpRangeError, PayloadError, UrlencodedError};
/// Trait that implements general purpose operations on http messages
pub trait HttpMessage {
/// Read the message headers.
fn headers(&self) -> &HeaderMap;
#[doc(hidden)]
/// Get a header
fn get_header<H: Header>(&self) -> Option<H> where Self: Sized {
fn get_header<H: Header>(&self) -> Option<H>
where
Self: Sized,
{
if self.headers().contains_key(H::name()) {
H::parse(self).ok()
} else {
None
}
}
/// Read the request content type. If request does not contain
/// *Content-Type* header, empty str get returned.
fn content_type(&self) -> &str {
if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) {
if let Ok(content_type) = content_type.to_str() {
return content_type.split(';').next().unwrap().trim()
return content_type.split(';').next().unwrap().trim();
}
}
""
@ -73,7 +73,7 @@ pub trait HttpMessage {
Err(_) => Err(ContentTypeError::ParseError),
};
} else {
return Err(ContentTypeError::ParseError)
return Err(ContentTypeError::ParseError);
}
}
Ok(None)
@ -96,8 +96,10 @@ pub trait HttpMessage {
/// `size` is full size of response (file).
fn range(&self, size: u64) -> Result<Vec<HttpRange>, HttpRangeError> {
if let Some(range) = self.headers().get(header::RANGE) {
HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size)
.map_err(|e| e.into())
HttpRange::parse(
unsafe { str::from_utf8_unchecked(range.as_bytes()) },
size,
).map_err(|e| e.into())
} else {
Ok(Vec::new())
}
@ -105,8 +107,9 @@ pub trait HttpMessage {
/// Load http message body.
///
/// By default only 256Kb payload reads to a memory, then `PayloadError::Overflow`
/// get returned. Use `MessageBody::limit()` method to change upper limit.
/// By default only 256Kb payload reads to a memory, then
/// `PayloadError::Overflow` get returned. Use `MessageBody::limit()`
/// method to change upper limit.
///
/// ## Server example
///
@ -131,14 +134,15 @@ pub trait HttpMessage {
/// # fn main() {}
/// ```
fn body(self) -> MessageBody<Self>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized
where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{
MessageBody::new(self)
}
/// Parse `application/x-www-form-urlencoded` encoded request's body.
/// Return `UrlEncoded` future. Form can be deserialized to any type that implements
/// `Deserialize` trait from *serde*.
/// Return `UrlEncoded` future. Form can be deserialized to any type that
/// implements `Deserialize` trait from *serde*.
///
/// Returns error:
///
@ -167,7 +171,8 @@ pub trait HttpMessage {
/// # fn main() {}
/// ```
fn urlencoded<T: DeserializeOwned>(self) -> UrlEncoded<Self, T>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized
where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{
UrlEncoded::new(self)
}
@ -205,7 +210,8 @@ pub trait HttpMessage {
/// # fn main() {}
/// ```
fn json<T: DeserializeOwned>(self) -> JsonBody<Self, T>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized
where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{
JsonBody::new(self)
}
@ -247,7 +253,8 @@ pub trait HttpMessage {
/// # fn main() {}
/// ```
fn multipart(self) -> Multipart<Self>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized
where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{
let boundary = Multipart::boundary(self.headers());
Multipart::new(boundary, self)
@ -258,11 +265,10 @@ pub trait HttpMessage {
pub struct MessageBody<T> {
limit: usize,
req: Option<T>,
fut: Option<Box<Future<Item=Bytes, Error=PayloadError>>>,
fut: Option<Box<Future<Item = Bytes, Error = PayloadError>>>,
}
impl<T> MessageBody<T> {
/// Create `RequestBody` for request.
pub fn new(req: T) -> MessageBody<T> {
MessageBody {
@ -280,7 +286,8 @@ impl<T> MessageBody<T> {
}
impl<T> Future for MessageBody<T>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static
where
T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
{
type Item = Bytes;
type Error = PayloadError;
@ -313,11 +320,14 @@ impl<T> Future for MessageBody<T>
Ok(body)
}
})
.map(|body| body.freeze())
.map(|body| body.freeze()),
));
}
self.fut.as_mut().expect("UrlEncoded could not be used second time").poll()
self.fut
.as_mut()
.expect("UrlEncoded could not be used second time")
.poll()
}
}
@ -325,7 +335,7 @@ impl<T> Future for MessageBody<T>
pub struct UrlEncoded<T, U> {
req: Option<T>,
limit: usize,
fut: Option<Box<Future<Item=U, Error=UrlencodedError>>>,
fut: Option<Box<Future<Item = U, Error = UrlencodedError>>>,
}
impl<T, U> UrlEncoded<T, U> {
@ -345,8 +355,9 @@ impl<T, U> UrlEncoded<T, U> {
}
impl<T, U> Future for UrlEncoded<T, U>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static,
U: DeserializeOwned + 'static
where
T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
U: DeserializeOwned + 'static,
{
type Item = U;
type Error = UrlencodedError;
@ -354,7 +365,7 @@ impl<T, U> Future for UrlEncoded<T, U>
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(req) = self.req.take() {
if req.chunked().unwrap_or(false) {
return Err(UrlencodedError::Chunked)
return Err(UrlencodedError::Chunked);
} else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) {
if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() {
@ -362,18 +373,19 @@ impl<T, U> Future for UrlEncoded<T, U>
return Err(UrlencodedError::Overflow);
}
} else {
return Err(UrlencodedError::UnknownLength)
return Err(UrlencodedError::UnknownLength);
}
} else {
return Err(UrlencodedError::UnknownLength)
return Err(UrlencodedError::UnknownLength);
}
}
// check content type
if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" {
return Err(UrlencodedError::ContentType)
return Err(UrlencodedError::ContentType);
}
let encoding = req.encoding().map_err(|_| UrlencodedError::ContentType)?;
let encoding = req.encoding()
.map_err(|_| UrlencodedError::ContentType)?;
// future
let limit = self.limit;
@ -392,7 +404,8 @@ impl<T, U> Future for UrlEncoded<T, U>
serde_urlencoded::from_bytes::<U>(&body)
.map_err(|_| UrlencodedError::Parse)
} else {
let body = encoding.decode(&body, DecoderTrap::Strict)
let body = encoding
.decode(&body, DecoderTrap::Strict)
.map_err(|_| UrlencodedError::Parse)?;
serde_urlencoded::from_str::<U>(&body)
.map_err(|_| UrlencodedError::Parse)
@ -401,19 +414,22 @@ impl<T, U> Future for UrlEncoded<T, U>
self.fut = Some(Box::new(fut));
}
self.fut.as_mut().expect("UrlEncoded could not be used second time").poll()
self.fut
.as_mut()
.expect("UrlEncoded could not be used second time")
.poll()
}
}
#[cfg(test)]
mod tests {
use super::*;
use mime;
use encoding::Encoding;
use encoding::all::ISO_8859_2;
use futures::Async;
use http::{Method, Version, Uri};
use http::{Method, Uri, Version};
use httprequest::HttpRequest;
use mime;
use std::str::FromStr;
use test::TestRequest;
@ -421,8 +437,9 @@ mod tests {
fn test_content_type() {
let req = TestRequest::with_header("content-type", "text/plain").finish();
assert_eq!(req.content_type(), "text/plain");
let req = TestRequest::with_header(
"content-type", "application/json; charset=utf=8").finish();
let req =
TestRequest::with_header("content-type", "application/json; charset=utf=8")
.finish();
assert_eq!(req.content_type(), "application/json");
let req = HttpRequest::default();
assert_eq!(req.content_type(), "");
@ -434,8 +451,9 @@ mod tests {
assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON));
let req = HttpRequest::default();
assert_eq!(req.mime_type().unwrap(), None);
let req = TestRequest::with_header(
"content-type", "application/json; charset=utf-8").finish();
let req =
TestRequest::with_header("content-type", "application/json; charset=utf-8")
.finish();
let mt = req.mime_type().unwrap().unwrap();
assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8));
assert_eq!(mt.type_(), mime::APPLICATION);
@ -445,7 +463,9 @@ mod tests {
#[test]
fn test_mime_type_error() {
let req = TestRequest::with_header(
"content-type", "applicationadfadsfasdflknadsfklnadsfjson").finish();
"content-type",
"applicationadfadsfasdflknadsfklnadsfjson",
).finish();
assert_eq!(Err(ContentTypeError::ParseError), req.mime_type());
}
@ -454,24 +474,32 @@ mod tests {
let req = HttpRequest::default();
assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
let req = TestRequest::with_header(
"content-type", "application/json").finish();
let req = TestRequest::with_header("content-type", "application/json").finish();
assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
let req = TestRequest::with_header(
"content-type", "application/json; charset=ISO-8859-2").finish();
"content-type",
"application/json; charset=ISO-8859-2",
).finish();
assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name());
}
#[test]
fn test_encoding_error() {
let req = TestRequest::with_header(
"content-type", "applicatjson").finish();
assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err());
let req = TestRequest::with_header("content-type", "applicatjson").finish();
assert_eq!(
Some(ContentTypeError::ParseError),
req.encoding().err()
);
let req = TestRequest::with_header(
"content-type", "application/json; charset=kkkttktk").finish();
assert_eq!(Some(ContentTypeError::UnknownEncoding), req.encoding().err());
"content-type",
"application/json; charset=kkkttktk",
).finish();
assert_eq!(
Some(ContentTypeError::UnknownEncoding),
req.encoding().err()
);
}
#[test]
@ -495,17 +523,26 @@ mod tests {
let req = HttpRequest::default();
assert!(!req.chunked().unwrap());
let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
let req =
TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
assert!(req.chunked().unwrap());
let mut headers = HeaderMap::new();
let s = unsafe{str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())};
let s = unsafe {
str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())
};
headers.insert(header::TRANSFER_ENCODING,
header::HeaderValue::from_str(s).unwrap());
headers.insert(
header::TRANSFER_ENCODING,
header::HeaderValue::from_str(s).unwrap(),
);
let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert!(req.chunked().is_err());
}
@ -540,51 +577,75 @@ mod tests {
#[test]
fn test_urlencoded_error() {
let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
assert_eq!(req.urlencoded::<Info>()
.poll().err().unwrap(), UrlencodedError::Chunked);
let req =
TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
assert_eq!(
req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::Chunked
);
let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::CONTENT_LENGTH, "xxxx")
header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "xxxx")
.finish();
assert_eq!(req.urlencoded::<Info>()
.poll().err().unwrap(), UrlencodedError::UnknownLength);
assert_eq!(
req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::UnknownLength
);
let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::CONTENT_LENGTH, "1000000")
header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "1000000")
.finish();
assert_eq!(req.urlencoded::<Info>()
.poll().err().unwrap(), UrlencodedError::Overflow);
assert_eq!(
req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::Overflow
);
let req = TestRequest::with_header(
header::CONTENT_TYPE, "text/plain")
let req = TestRequest::with_header(header::CONTENT_TYPE, "text/plain")
.header(header::CONTENT_LENGTH, "10")
.finish();
assert_eq!(req.urlencoded::<Info>()
.poll().err().unwrap(), UrlencodedError::ContentType);
assert_eq!(
req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::ContentType
);
}
#[test]
fn test_urlencoded() {
let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::CONTENT_LENGTH, "11")
header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "11")
.finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let result = req.urlencoded::<Info>().poll().ok().unwrap();
assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()}));
assert_eq!(
result,
Async::Ready(Info {
hello: "world".to_owned()
})
);
let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8")
.header(header::CONTENT_LENGTH, "11")
header::CONTENT_TYPE,
"application/x-www-form-urlencoded; charset=utf-8",
).header(header::CONTENT_LENGTH, "11")
.finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let result = req.urlencoded().poll().ok().unwrap();
assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()}));
assert_eq!(
result,
Async::Ready(Info {
hello: "world".to_owned()
})
);
}
#[test]
@ -602,14 +663,16 @@ mod tests {
}
let mut req = HttpRequest::default();
req.payload_mut().unread_data(Bytes::from_static(b"test"));
req.payload_mut()
.unread_data(Bytes::from_static(b"test"));
match req.body().poll().ok().unwrap() {
Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")),
_ => unreachable!("error"),
}
let mut req = HttpRequest::default();
req.payload_mut().unread_data(Bytes::from_static(b"11111111111111"));
req.payload_mut()
.unread_data(Bytes::from_static(b"11111111111111"));
match req.body().limit(5).poll().err().unwrap() {
PayloadError::Overflow => (),
_ => unreachable!("error"),

View File

@ -1,30 +1,29 @@
//! HTTP Request message related code.
use std::{io, cmp, str, fmt, mem};
use std::rc::Rc;
use std::net::SocketAddr;
use std::borrow::Cow;
use bytes::Bytes;
use cookie::Cookie;
use futures::{Async, Stream, Poll};
use futures::future::{FutureResult, result};
use futures_cpupool::CpuPool;
use failure;
use url::{Url, form_urlencoded};
use http::{header, Uri, Method, Version, HeaderMap, Extensions, StatusCode};
use tokio_io::AsyncRead;
use futures::future::{result, FutureResult};
use futures::{Async, Poll, Stream};
use futures_cpupool::CpuPool;
use http::{header, Extensions, HeaderMap, Method, StatusCode, Uri, Version};
use percent_encoding::percent_decode;
use std::borrow::Cow;
use std::net::SocketAddr;
use std::rc::Rc;
use std::{cmp, fmt, io, mem, str};
use tokio_io::AsyncRead;
use url::{form_urlencoded, Url};
use body::Body;
use info::ConnectionInfo;
use param::Params;
use router::{Router, Resource};
use payload::Payload;
use error::{CookieParseError, Error, PayloadError, UrlGenerationError};
use handler::FromRequest;
use httpmessage::HttpMessage;
use httpresponse::{HttpResponse, HttpResponseBuilder};
use info::ConnectionInfo;
use param::Params;
use payload::Payload;
use router::{Resource, Router};
use server::helpers::SharedHttpInnerMessage;
use error::{Error, UrlGenerationError, CookieParseError, PayloadError};
pub struct HttpInnerMessage {
pub version: Version,
@ -42,14 +41,13 @@ pub struct HttpInnerMessage {
resource: RouterResource,
}
#[derive(Debug, Copy, Clone,PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq)]
enum RouterResource {
Notset,
Normal(u16),
}
impl Default for HttpInnerMessage {
fn default() -> HttpInnerMessage {
HttpInnerMessage {
method: Method::GET,
@ -70,7 +68,6 @@ impl Default for HttpInnerMessage {
}
impl HttpInnerMessage {
/// Checks if a connection should be kept alive.
#[inline]
pub fn keep_alive(&self) -> bool {
@ -79,8 +76,8 @@ impl HttpInnerMessage {
if self.version == Version::HTTP_10 && conn.contains("keep-alive") {
true
} else {
self.version == Version::HTTP_11 &&
!(conn.contains("close") || conn.contains("upgrade"))
self.version == Version::HTTP_11
&& !(conn.contains("close") || conn.contains("upgrade"))
}
} else {
false
@ -105,21 +102,20 @@ impl HttpInnerMessage {
}
}
lazy_static!{
lazy_static! {
static ref RESOURCE: Resource = Resource::unset();
}
/// An HTTP Request
pub struct HttpRequest<S=()>(SharedHttpInnerMessage, Option<Rc<S>>, Option<Router>);
pub struct HttpRequest<S = ()>(SharedHttpInnerMessage, Option<Rc<S>>, Option<Router>);
impl HttpRequest<()> {
/// Construct a new Request.
#[inline]
pub fn new(method: Method, uri: Uri,
version: Version, headers: HeaderMap, payload: Option<Payload>)
-> HttpRequest
{
pub fn new(
method: Method, uri: Uri, version: Version, headers: HeaderMap,
payload: Option<Payload>,
) -> HttpRequest {
HttpRequest(
SharedHttpInnerMessage::from_message(HttpInnerMessage {
method,
@ -142,7 +138,7 @@ impl HttpRequest<()> {
}
#[inline(always)]
#[cfg_attr(feature="cargo-clippy", allow(inline_always))]
#[cfg_attr(feature = "cargo-clippy", allow(inline_always))]
pub(crate) fn from_message(msg: SharedHttpInnerMessage) -> HttpRequest {
HttpRequest(msg, None, None)
}
@ -154,7 +150,6 @@ impl HttpRequest<()> {
}
}
impl<S> HttpMessage for HttpRequest<S> {
#[inline]
fn headers(&self) -> &HeaderMap {
@ -163,7 +158,6 @@ impl<S> HttpMessage for HttpRequest<S> {
}
impl<S> HttpRequest<S> {
#[inline]
/// Construct new http request with state.
pub fn change_state<NS>(&self, state: Rc<NS>) -> HttpRequest<NS> {
@ -211,8 +205,10 @@ impl<S> HttpRequest<S> {
#[inline]
#[doc(hidden)]
pub fn cpu_pool(&self) -> &CpuPool {
self.router().expect("HttpRequest has to have Router instance")
.server_settings().cpu_pool()
self.router()
.expect("HttpRequest has to have Router instance")
.server_settings()
.cpu_pool()
}
/// Create http response
@ -235,12 +231,18 @@ impl<S> HttpRequest<S> {
#[doc(hidden)]
pub fn prefix_len(&self) -> usize {
if let Some(router) = self.router() { router.prefix().len() } else { 0 }
if let Some(router) = self.router() {
router.prefix().len()
} else {
0
}
}
/// Read the Request Uri.
#[inline]
pub fn uri(&self) -> &Uri { &self.as_ref().uri }
pub fn uri(&self) -> &Uri {
&self.as_ref().uri
}
/// Returns mutable the Request Uri.
///
@ -252,7 +254,9 @@ impl<S> HttpRequest<S> {
/// Read the Request method.
#[inline]
pub fn method(&self) -> &Method { &self.as_ref().method }
pub fn method(&self) -> &Method {
&self.as_ref().method
}
/// Read the Request Version.
#[inline]
@ -277,14 +281,16 @@ impl<S> HttpRequest<S> {
/// Percent decoded path of this Request.
#[inline]
pub fn path_decoded(&self) -> Cow<str> {
percent_decode(self.uri().path().as_bytes()).decode_utf8().unwrap()
percent_decode(self.uri().path().as_bytes())
.decode_utf8()
.unwrap()
}
/// Get *ConnectionInfo* for correct request.
pub fn connection_info(&self) -> &ConnectionInfo {
if self.as_ref().info.is_none() {
let info: ConnectionInfo<'static> = unsafe{
mem::transmute(ConnectionInfo::new(self))};
let info: ConnectionInfo<'static> =
unsafe { mem::transmute(ConnectionInfo::new(self)) };
self.as_mut().info = Some(info);
}
self.as_ref().info.as_ref().unwrap()
@ -310,9 +316,12 @@ impl<S> HttpRequest<S> {
/// .finish();
/// }
/// ```
pub fn url_for<U, I>(&self, name: &str, elements: U) -> Result<Url, UrlGenerationError>
where U: IntoIterator<Item=I>,
I: AsRef<str>,
pub fn url_for<U, I>(
&self, name: &str, elements: U
) -> Result<Url, UrlGenerationError>
where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{
if self.router().is_none() {
Err(UrlGenerationError::RouterNotAvailable)
@ -320,7 +329,12 @@ impl<S> HttpRequest<S> {
let path = self.router().unwrap().resource_path(name, elements)?;
if path.starts_with('/') {
let conn = self.connection_info();
Ok(Url::parse(&format!("{}://{}{}", conn.scheme(), conn.host(), path))?)
Ok(Url::parse(&format!(
"{}://{}{}",
conn.scheme(),
conn.host(),
path
))?)
} else {
Ok(Url::parse(&path)?)
}
@ -338,7 +352,7 @@ impl<S> HttpRequest<S> {
pub fn resource(&self) -> &Resource {
if let Some(ref router) = self.2 {
if let RouterResource::Normal(idx) = self.as_ref().resource {
return router.get_resource(idx as usize)
return router.get_resource(idx as usize);
}
}
&*RESOURCE
@ -353,7 +367,8 @@ impl<S> HttpRequest<S> {
/// Peer address is actual socket address, if proxy is used in front of
/// actix http server, then peer address would be address of this proxy.
///
/// To get client connection information `connection_info()` method should be used.
/// To get client connection information `connection_info()` method should
/// be used.
#[inline]
pub fn peer_addr(&self) -> Option<&SocketAddr> {
self.as_ref().addr.as_ref()
@ -368,13 +383,14 @@ impl<S> HttpRequest<S> {
/// Params is a container for url query parameters.
pub fn query(&self) -> &Params {
if !self.as_ref().query_loaded {
let params: &mut Params = unsafe{ mem::transmute(&mut self.as_mut().query) };
let params: &mut Params =
unsafe { mem::transmute(&mut self.as_mut().query) };
self.as_mut().query_loaded = true;
for (key, val) in form_urlencoded::parse(self.query_string().as_ref()) {
params.add(key, val);
}
}
unsafe{ mem::transmute(&self.as_ref().query) }
unsafe { mem::transmute(&self.as_ref().query) }
}
/// The query string in the URL.
@ -412,7 +428,7 @@ impl<S> HttpRequest<S> {
if let Ok(cookies) = self.cookies() {
for cookie in cookies {
if cookie.name() == name {
return Some(cookie)
return Some(cookie);
}
}
}
@ -423,16 +439,17 @@ impl<S> HttpRequest<S> {
///
/// Params is a container for url parameters.
/// Route supports glob patterns: * for a single wildcard segment and :param
/// for matching storing that segment of the request url in the Params object.
/// for matching storing that segment of the request url in the Params
/// object.
#[inline]
pub fn match_info(&self) -> &Params {
unsafe{ mem::transmute(&self.as_ref().params) }
unsafe { mem::transmute(&self.as_ref().params) }
}
/// Get mutable reference to request's Params.
#[inline]
pub fn match_info_mut(&mut self) -> &mut Params {
unsafe{ mem::transmute(&mut self.as_mut().params) }
unsafe { mem::transmute(&mut self.as_mut().params) }
}
/// Checks if a connection should be kept alive.
@ -444,7 +461,7 @@ impl<S> HttpRequest<S> {
pub(crate) fn upgrade(&self) -> bool {
if let Some(conn) = self.as_ref().headers.get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
return s.to_lowercase().contains("upgrade")
return s.to_lowercase().contains("upgrade");
}
}
self.as_ref().method == Method::CONNECT
@ -479,7 +496,6 @@ impl<S> HttpRequest<S> {
}
impl Default for HttpRequest<()> {
/// Construct default request
fn default() -> HttpRequest {
HttpRequest(SharedHttpInnerMessage::default(), None, None)
@ -492,8 +508,7 @@ impl<S> Clone for HttpRequest<S> {
}
}
impl<S: 'static> FromRequest<S> for HttpRequest<S>
{
impl<S: 'static> FromRequest<S> for HttpRequest<S> {
type Config = ();
type Result = FutureResult<Self, Error>;
@ -540,10 +555,13 @@ impl<S> io::Read for HttpRequest<S> {
}
}
Ok(Async::Ready(None)) => Ok(0),
Ok(Async::NotReady) =>
Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready")),
Err(e) =>
Err(io::Error::new(io::ErrorKind::Other, failure::Error::from(e).compat())),
Ok(Async::NotReady) => {
Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready"))
}
Err(e) => Err(io::Error::new(
io::ErrorKind::Other,
failure::Error::from(e).compat(),
)),
}
} else {
Ok(0)
@ -556,8 +574,12 @@ impl<S> AsyncRead for HttpRequest<S> {}
impl<S> fmt::Debug for HttpRequest<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(
f, "\nHttpRequest {:?} {}:{}",
self.as_ref().version, self.as_ref().method, self.path_decoded());
f,
"\nHttpRequest {:?} {}:{}",
self.as_ref().version,
self.as_ref().method,
self.path_decoded()
);
if !self.query_string().is_empty() {
let _ = writeln!(f, " query: ?{:?}", self.query_string());
}
@ -575,11 +597,11 @@ impl<S> fmt::Debug for HttpRequest<S> {
#[cfg(test)]
mod tests {
use super::*;
use http::{Uri, HttpTryFrom};
use router::Resource;
use http::{HttpTryFrom, Uri};
use resource::ResourceHandler;
use test::TestRequest;
use router::Resource;
use server::ServerSettings;
use test::TestRequest;
#[test]
fn test_debug() {
@ -652,12 +674,19 @@ mod tests {
#[test]
fn test_url_for() {
let req2 = HttpRequest::default();
assert_eq!(req2.url_for("unknown", &["test"]),
Err(UrlGenerationError::RouterNotAvailable));
assert_eq!(
req2.url_for("unknown", &["test"]),
Err(UrlGenerationError::RouterNotAvailable)
);
let mut resource = ResourceHandler::<()>::default();
resource.name("index");
let routes = vec!((Resource::new("index", "/user/{name}.{ext}"), Some(resource)));
let routes = vec![
(
Resource::new("index", "/user/{name}.{ext}"),
Some(resource),
),
];
let (router, _) = Router::new("/", ServerSettings::default(), routes);
assert!(router.has_route("/user/test.html"));
assert!(!router.has_route("/test/unknown"));
@ -665,12 +694,19 @@ mod tests {
let req = TestRequest::with_header(header::HOST, "www.rust-lang.org")
.finish_with_router(router);
assert_eq!(req.url_for("unknown", &["test"]),
Err(UrlGenerationError::ResourceNotFound));
assert_eq!(req.url_for("index", &["test"]),
Err(UrlGenerationError::NotEnoughElements));
assert_eq!(
req.url_for("unknown", &["test"]),
Err(UrlGenerationError::ResourceNotFound)
);
assert_eq!(
req.url_for("index", &["test"]),
Err(UrlGenerationError::NotEnoughElements)
);
let url = req.url_for("index", &["test", "html"]);
assert_eq!(url.ok().unwrap().as_str(), "http://www.rust-lang.org/user/test.html");
assert_eq!(
url.ok().unwrap().as_str(),
"http://www.rust-lang.org/user/test.html"
);
}
#[test]
@ -679,15 +715,22 @@ mod tests {
let mut resource = ResourceHandler::<()>::default();
resource.name("index");
let routes = vec![(Resource::new("index", "/user/{name}.{ext}"), Some(resource))];
let routes = vec![
(
Resource::new("index", "/user/{name}.{ext}"),
Some(resource),
),
];
let (router, _) = Router::new("/prefix/", ServerSettings::default(), routes);
assert!(router.has_route("/user/test.html"));
assert!(!router.has_route("/prefix/user/test.html"));
let req = req.with_state(Rc::new(()), router);
let url = req.url_for("index", &["test", "html"]);
assert_eq!(url.ok().unwrap().as_str(),
"http://www.rust-lang.org/prefix/user/test.html");
assert_eq!(
url.ok().unwrap().as_str(),
"http://www.rust-lang.org/prefix/user/test.html"
);
}
#[test]
@ -697,12 +740,19 @@ mod tests {
let mut resource = ResourceHandler::<()>::default();
resource.name("index");
let routes = vec![
(Resource::external("youtube", "https://youtube.com/watch/{video_id}"), None)];
(
Resource::external("youtube", "https://youtube.com/watch/{video_id}"),
None,
),
];
let (router, _) = Router::new::<()>("", ServerSettings::default(), routes);
assert!(!router.has_route("https://youtube.com/watch/unknown"));
let req = req.with_state(Rc::new(()), router);
let url = req.url_for("youtube", &["oHg5SJYRHA0"]);
assert_eq!(url.ok().unwrap().as_str(), "https://youtube.com/watch/oHg5SJYRHA0");
assert_eq!(
url.ok().unwrap().as_str(),
"https://youtube.com/watch/oHg5SJYRHA0"
);
}
}

View File

@ -1,30 +1,29 @@
//! Http response
use std::{mem, str, fmt};
use std::rc::Rc;
use std::io::Write;
use std::cell::UnsafeCell;
use std::collections::VecDeque;
use std::io::Write;
use std::rc::Rc;
use std::{fmt, mem, str};
use bytes::{BufMut, Bytes, BytesMut};
use cookie::{Cookie, CookieJar};
use bytes::{Bytes, BytesMut, BufMut};
use futures::Stream;
use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use serde_json;
use http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode, Version};
use serde::Serialize;
use serde_json;
use body::Body;
use client::ClientResponse;
use error::Error;
use handler::Responder;
use header::{Header, IntoHeaderValue, ContentEncoding};
use httprequest::HttpRequest;
use header::{ContentEncoding, Header, IntoHeaderValue};
use httpmessage::HttpMessage;
use client::ClientResponse;
use httprequest::HttpRequest;
/// max write buffer size 64k
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// Represents various types of connection
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum ConnectionType {
@ -37,7 +36,10 @@ pub enum ConnectionType {
}
/// An HTTP Response
pub struct HttpResponse(Option<Box<InnerHttpResponse>>, Rc<UnsafeCell<HttpResponsePool>>);
pub struct HttpResponse(
Option<Box<InnerHttpResponse>>,
Rc<UnsafeCell<HttpResponsePool>>,
);
impl Drop for HttpResponse {
fn drop(&mut self) {
@ -48,7 +50,6 @@ impl Drop for HttpResponse {
}
impl HttpResponse {
#[inline(always)]
#[cfg_attr(feature = "cargo-clippy", allow(inline_always))]
fn get_ref(&self) -> &InnerHttpResponse {
@ -103,7 +104,7 @@ impl HttpResponse {
response,
pool,
err: None,
cookies: None, // TODO: convert set-cookie headers
cookies: None, // TODO: convert set-cookie headers
}
}
@ -149,7 +150,10 @@ impl HttpResponse {
if let Some(reason) = self.get_ref().reason {
reason
} else {
self.get_ref().status.canonical_reason().unwrap_or("<unknown status code>")
self.get_ref()
.status
.canonical_reason()
.unwrap_or("<unknown status code>")
}
}
@ -241,9 +245,13 @@ impl HttpResponse {
impl fmt::Debug for HttpResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(f, "\nHttpResponse {:?} {}{}",
self.get_ref().version, self.get_ref().status,
self.get_ref().reason.unwrap_or(""));
let res = writeln!(
f,
"\nHttpResponse {:?} {}{}",
self.get_ref().version,
self.get_ref().status,
self.get_ref().reason.unwrap_or("")
);
let _ = writeln!(f, " encoding: {:?}", self.get_ref().encoding);
let _ = writeln!(f, " headers:");
for (key, val) in self.get_ref().headers.iter() {
@ -299,11 +307,12 @@ impl HttpResponseBuilder {
/// fn main() {}
/// ```
#[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self
{
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
if let Some(parts) = parts(&mut self.response, &self.err) {
match hdr.try_into() {
Ok(value) => { parts.headers.append(H::name(), value); }
Ok(value) => {
parts.headers.append(H::name(), value);
}
Err(e) => self.err = Some(e.into()),
}
}
@ -325,16 +334,17 @@ impl HttpResponseBuilder {
/// fn main() {}
/// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.response, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => {
match value.try_into() {
Ok(value) => { parts.headers.append(key, value); }
Err(e) => self.err = Some(e.into()),
Ok(key) => match value.try_into() {
Ok(value) => {
parts.headers.append(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
@ -354,8 +364,9 @@ impl HttpResponseBuilder {
/// Set content encoding.
///
/// By default `ContentEncoding::Auto` is used, which automatically
/// negotiates content encoding based on request's `Accept-Encoding` headers.
/// To enforce specific encoding, use specific ContentEncoding` value.
/// negotiates content encoding based on request's `Accept-Encoding`
/// headers. To enforce specific encoding, use specific
/// ContentEncoding` value.
#[inline]
pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self {
if let Some(parts) = parts(&mut self.response, &self.err) {
@ -408,11 +419,14 @@ impl HttpResponseBuilder {
/// Set response content type
#[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self
where HeaderValue: HttpTryFrom<V>
where
HeaderValue: HttpTryFrom<V>,
{
if let Some(parts) = parts(&mut self.response, &self.err) {
match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); },
Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
Err(e) => self.err = Some(e.into()),
};
}
@ -452,12 +466,16 @@ impl HttpResponseBuilder {
jar.add(cookie.into_owned());
self.cookies = Some(jar)
} else {
self.cookies.as_mut().unwrap().add(cookie.into_owned());
self.cookies
.as_mut()
.unwrap()
.add(cookie.into_owned());
}
self
}
/// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()` method.
/// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()`
/// method.
pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self {
{
if self.cookies.is_none() {
@ -471,9 +489,11 @@ impl HttpResponseBuilder {
self
}
/// This method calls provided closure with builder reference if value is true.
/// This method calls provided closure with builder reference if value is
/// true.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: FnOnce(&mut HttpResponseBuilder)
where
F: FnOnce(&mut HttpResponseBuilder),
{
if value {
f(self);
@ -481,9 +501,11 @@ impl HttpResponseBuilder {
self
}
/// This method calls provided closure with builder reference if value is Some.
/// This method calls provided closure with builder reference if value is
/// Some.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where F: FnOnce(T, &mut HttpResponseBuilder)
where
F: FnOnce(T, &mut HttpResponseBuilder),
{
if let Some(val) = value {
f(val, self);
@ -494,8 +516,8 @@ impl HttpResponseBuilder {
/// Set write buffer capacity
///
/// This parameter makes sense only for streaming response
/// or actor. If write buffer reaches specified capacity, stream or actor get
/// paused.
/// or actor. If write buffer reaches specified capacity, stream or actor
/// get paused.
///
/// Default write buffer capacity is 64kb
pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self {
@ -510,9 +532,11 @@ impl HttpResponseBuilder {
/// `HttpResponseBuilder` can not be used after this call.
pub fn body<B: Into<Body>>(&mut self, body: B) -> HttpResponse {
if let Some(e) = self.err.take() {
return Error::from(e).into()
return Error::from(e).into();
}
let mut response = self.response.take().expect("cannot reuse response builder");
let mut response = self.response
.take()
.expect("cannot reuse response builder");
if let Some(ref jar) = self.cookies {
for cookie in jar.delta() {
match HeaderValue::from_str(&cookie.to_string()) {
@ -530,10 +554,13 @@ impl HttpResponseBuilder {
///
/// `HttpResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> HttpResponse
where S: Stream<Item=Bytes, Error=E> + 'static,
E: Into<Error>,
where
S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error>,
{
self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into()))))
self.body(Body::Streaming(Box::new(
stream.map_err(|e| e.into()),
)))
}
/// Set a json body and generate `HttpResponse`
@ -542,19 +569,19 @@ impl HttpResponseBuilder {
pub fn json<T: Serialize>(&mut self, value: T) -> HttpResponse {
match serde_json::to_string(&value) {
Ok(body) => {
let contains =
if let Some(parts) = parts(&mut self.response, &self.err) {
parts.headers.contains_key(header::CONTENT_TYPE)
} else {
true
};
let contains = if let Some(parts) = parts(&mut self.response, &self.err)
{
parts.headers.contains_key(header::CONTENT_TYPE)
} else {
true
};
if !contains {
self.header(header::CONTENT_TYPE, "application/json");
}
self.body(body)
},
Err(e) => Error::from(e).into()
}
Err(e) => Error::from(e).into(),
}
}
@ -579,11 +606,11 @@ impl HttpResponseBuilder {
#[inline]
#[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))]
fn parts<'a>(parts: &'a mut Option<Box<InnerHttpResponse>>, err: &Option<HttpError>)
-> Option<&'a mut Box<InnerHttpResponse>>
{
fn parts<'a>(
parts: &'a mut Option<Box<InnerHttpResponse>>, err: &Option<HttpError>
) -> Option<&'a mut Box<InnerHttpResponse>> {
if err.is_some() {
return None
return None;
}
parts.as_mut()
}
@ -628,8 +655,8 @@ impl Responder for &'static str {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8")
.body(self))
.content_type("text/plain; charset=utf-8")
.body(self))
}
}
@ -647,8 +674,8 @@ impl Responder for &'static [u8] {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream")
.body(self))
.content_type("application/octet-stream")
.body(self))
}
}
@ -666,8 +693,8 @@ impl Responder for String {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8")
.body(self))
.content_type("text/plain; charset=utf-8")
.body(self))
}
}
@ -685,8 +712,8 @@ impl<'a> Responder for &'a String {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8")
.body(self))
.content_type("text/plain; charset=utf-8")
.body(self))
}
}
@ -704,8 +731,8 @@ impl Responder for Bytes {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream")
.body(self))
.content_type("application/octet-stream")
.body(self))
}
}
@ -723,8 +750,8 @@ impl Responder for BytesMut {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream")
.body(self))
.content_type("application/octet-stream")
.body(self))
}
}
@ -745,7 +772,9 @@ impl<'a> From<&'a ClientResponse> for HttpResponseBuilder {
impl<'a, S> From<&'a HttpRequest<S>> for HttpResponseBuilder {
fn from(req: &'a HttpRequest<S>) -> HttpResponseBuilder {
if let Some(router) = req.router() {
router.server_settings().get_response_builder(StatusCode::OK)
router
.server_settings()
.get_response_builder(StatusCode::OK)
} else {
HttpResponse::Ok()
}
@ -768,7 +797,6 @@ struct InnerHttpResponse {
}
impl InnerHttpResponse {
#[inline]
fn new(status: StatusCode, body: Body) -> InnerHttpResponse {
InnerHttpResponse {
@ -793,38 +821,41 @@ pub(crate) struct HttpResponsePool(VecDeque<Box<InnerHttpResponse>>);
thread_local!(static POOL: Rc<UnsafeCell<HttpResponsePool>> = HttpResponsePool::pool());
impl HttpResponsePool {
pub fn pool() -> Rc<UnsafeCell<HttpResponsePool>> {
Rc::new(UnsafeCell::new(HttpResponsePool(VecDeque::with_capacity(128))))
Rc::new(UnsafeCell::new(HttpResponsePool(
VecDeque::with_capacity(128),
)))
}
#[inline]
pub fn get_builder(pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode)
-> HttpResponseBuilder
{
let p = unsafe{&mut *pool.as_ref().get()};
pub fn get_builder(
pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode
) -> HttpResponseBuilder {
let p = unsafe { &mut *pool.as_ref().get() };
if let Some(mut msg) = p.0.pop_front() {
msg.status = status;
HttpResponseBuilder {
response: Some(msg),
pool: Some(Rc::clone(pool)),
err: None,
cookies: None }
cookies: None,
}
} else {
let msg = Box::new(InnerHttpResponse::new(status, Body::Empty));
HttpResponseBuilder {
response: Some(msg),
pool: Some(Rc::clone(pool)),
err: None,
cookies: None }
cookies: None,
}
}
}
#[inline]
pub fn get_response(pool: &Rc<UnsafeCell<HttpResponsePool>>,
status: StatusCode, body: Body) -> HttpResponse
{
let p = unsafe{&mut *pool.as_ref().get()};
pub fn get_response(
pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode, body: Body
) -> HttpResponse {
let p = unsafe { &mut *pool.as_ref().get() };
if let Some(mut msg) = p.0.pop_front() {
msg.status = status;
msg.body = body;
@ -847,9 +878,10 @@ impl HttpResponsePool {
#[inline(always)]
#[cfg_attr(feature = "cargo-clippy", allow(boxed_local, inline_always))]
fn release(pool: &Rc<UnsafeCell<HttpResponsePool>>, mut inner: Box<InnerHttpResponse>)
{
let pool = unsafe{&mut *pool.as_ref().get()};
fn release(
pool: &Rc<UnsafeCell<HttpResponsePool>>, mut inner: Box<InnerHttpResponse>
) {
let pool = unsafe { &mut *pool.as_ref().get() };
if pool.0.len() < 128 {
inner.headers.clear();
inner.version = None;
@ -868,12 +900,12 @@ impl HttpResponsePool {
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
use time::Duration;
use http::{Method, Uri};
use http::header::{COOKIE, CONTENT_TYPE, HeaderValue};
use body::Binary;
use http;
use http::header::{HeaderValue, CONTENT_TYPE, COOKIE};
use http::{Method, Uri};
use std::str::FromStr;
use time::Duration;
#[test]
fn test_debug() {
@ -892,25 +924,37 @@ mod tests {
headers.insert(COOKIE, HeaderValue::from_static("cookie2=value2"));
let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let cookies = req.cookies().unwrap();
let resp = HttpResponse::Ok()
.cookie(http::Cookie::build("name", "value")
.cookie(
http::Cookie::build("name", "value")
.domain("www.rust-lang.org")
.path("/test")
.http_only(true)
.max_age(Duration::days(1))
.finish())
.finish(),
)
.del_cookie(&cookies[0])
.finish();
let mut val: Vec<_> = resp.headers().get_all("Set-Cookie")
.iter().map(|v| v.to_str().unwrap().to_owned()).collect();
let mut val: Vec<_> = resp.headers()
.get_all("Set-Cookie")
.iter()
.map(|v| v.to_str().unwrap().to_owned())
.collect();
val.sort();
assert!(val[0].starts_with("cookie2=; Max-Age=0;"));
assert_eq!(
val[1],"name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400");
val[1],
"name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400"
);
}
#[test]
@ -931,15 +975,21 @@ mod tests {
#[test]
fn test_force_close() {
let resp = HttpResponse::build(StatusCode::OK).force_close().finish();
let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
assert!(!resp.keep_alive().unwrap())
}
#[test]
fn test_content_type() {
let resp = HttpResponse::build(StatusCode::OK)
.content_type("text/plain").body(Body::Empty);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain")
.content_type("text/plain")
.body(Body::Empty);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
"text/plain"
)
}
#[test]
@ -947,25 +997,29 @@ mod tests {
let resp = HttpResponse::build(StatusCode::OK).finish();
assert_eq!(resp.content_encoding(), None);
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
{
let resp = HttpResponse::build(StatusCode::OK)
.content_encoding(ContentEncoding::Br).finish();
.content_encoding(ContentEncoding::Br)
.finish();
assert_eq!(resp.content_encoding(), Some(ContentEncoding::Br));
}
let resp = HttpResponse::build(StatusCode::OK)
.content_encoding(ContentEncoding::Gzip).finish();
.content_encoding(ContentEncoding::Gzip)
.finish();
assert_eq!(resp.content_encoding(), Some(ContentEncoding::Gzip));
}
#[test]
fn test_json() {
let resp = HttpResponse::build(StatusCode::OK)
.json(vec!["v1", "v2", "v3"]);
let resp = HttpResponse::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]);
let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, HeaderValue::from_static("application/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")));
assert_eq!(
*resp.body(),
Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))
);
}
#[test]
@ -975,7 +1029,10 @@ mod tests {
.json(vec!["v1", "v2", "v3"]);
let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, HeaderValue::from_static("text/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")));
assert_eq!(
*resp.body(),
Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))
);
}
impl Body {
@ -993,91 +1050,152 @@ mod tests {
let resp: HttpResponse = "test".into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = "test".respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = b"test".as_ref().into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(b"test".as_ref())
);
let resp: HttpResponse = b"test".as_ref().respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(b"test".as_ref())
);
let resp: HttpResponse = "test".to_owned().into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from("test".to_owned())
);
let resp: HttpResponse = "test".to_owned().respond_to(req.clone()).ok().unwrap();
let resp: HttpResponse = "test"
.to_owned()
.respond_to(req.clone())
.ok()
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from("test".to_owned())
);
let resp: HttpResponse = (&"test".to_owned()).into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(&"test".to_owned())
);
let resp: HttpResponse = (&"test".to_owned()).respond_to(req.clone()).ok().unwrap();
let resp: HttpResponse = (&"test".to_owned())
.respond_to(req.clone())
.ok()
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned()));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(&"test".to_owned())
);
let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test")));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(Bytes::from_static(b"test"))
);
let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test")));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(Bytes::from_static(b"test"))
);
let b = BytesMut::from("test");
let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test")));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(BytesMut::from("test"))
);
let b = BytesMut::from("test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream"));
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test")));
assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(BytesMut::from("test"))
);
}
#[test]

View File

@ -1,13 +1,12 @@
use std::str::FromStr;
use http::header::{self, HeaderName};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use std::str::FromStr;
const X_FORWARDED_FOR: &str = "X-FORWARDED-FOR";
const X_FORWARDED_HOST: &str = "X-FORWARDED-HOST";
const X_FORWARDED_PROTO: &str = "X-FORWARDED-PROTO";
/// `HttpRequest` connection information
pub struct ConnectionInfo<'a> {
scheme: &'a str,
@ -17,7 +16,6 @@ pub struct ConnectionInfo<'a> {
}
impl<'a> ConnectionInfo<'a> {
/// Create *ConnectionInfo* instance for a request.
#[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))]
pub fn new<S>(req: &'a HttpRequest<S>) -> ConnectionInfo<'a> {
@ -55,8 +53,9 @@ impl<'a> ConnectionInfo<'a> {
// scheme
if scheme.is_none() {
if let Some(h) = req.headers().get(
HeaderName::from_str(X_FORWARDED_PROTO).unwrap()) {
if let Some(h) = req.headers()
.get(HeaderName::from_str(X_FORWARDED_PROTO).unwrap())
{
if let Ok(h) = h.to_str() {
scheme = h.split(',').next().map(|v| v.trim());
}
@ -75,7 +74,9 @@ impl<'a> ConnectionInfo<'a> {
// host
if host.is_none() {
if let Some(h) = req.headers().get(HeaderName::from_str(X_FORWARDED_HOST).unwrap()) {
if let Some(h) = req.headers()
.get(HeaderName::from_str(X_FORWARDED_HOST).unwrap())
{
if let Ok(h) = h.to_str() {
host = h.split(',').next().map(|v| v.trim());
}
@ -97,13 +98,15 @@ impl<'a> ConnectionInfo<'a> {
// remote addr
if remote.is_none() {
if let Some(h) = req.headers().get(
HeaderName::from_str(X_FORWARDED_FOR).unwrap()) {
if let Some(h) = req.headers()
.get(HeaderName::from_str(X_FORWARDED_FOR).unwrap())
{
if let Ok(h) = h.to_str() {
remote = h.split(',').next().map(|v| v.trim());
}
}
if remote.is_none() { // get peeraddr from socketaddr
if remote.is_none() {
// get peeraddr from socketaddr
peer = req.peer_addr().map(|addr| format!("{}", addr));
}
}
@ -176,7 +179,9 @@ mod tests {
req.headers_mut().insert(
header::FORWARDED,
HeaderValue::from_static(
"for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org"));
"for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org",
),
);
let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "https");
@ -185,7 +190,9 @@ mod tests {
let mut req = HttpRequest::default();
req.headers_mut().insert(
header::HOST, HeaderValue::from_static("rust-lang.org"));
header::HOST,
HeaderValue::from_static("rust-lang.org"),
);
let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "http");
@ -194,20 +201,26 @@ mod tests {
let mut req = HttpRequest::default();
req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_FOR).unwrap(), HeaderValue::from_static("192.0.2.60"));
HeaderName::from_str(X_FORWARDED_FOR).unwrap(),
HeaderValue::from_static("192.0.2.60"),
);
let info = ConnectionInfo::new(&req);
assert_eq!(info.remote(), Some("192.0.2.60"));
let mut req = HttpRequest::default();
req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_HOST).unwrap(), HeaderValue::from_static("192.0.2.60"));
HeaderName::from_str(X_FORWARDED_HOST).unwrap(),
HeaderValue::from_static("192.0.2.60"),
);
let info = ConnectionInfo::new(&req);
assert_eq!(info.host(), "192.0.2.60");
assert_eq!(info.remote(), None);
let mut req = HttpRequest::default();
req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_PROTO).unwrap(), HeaderValue::from_static("https"));
HeaderName::from_str(X_FORWARDED_PROTO).unwrap(),
HeaderValue::from_static("https"),
);
let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "https");
}

View File

@ -1,16 +1,16 @@
use bytes::{Bytes, BytesMut};
use futures::{Future, Poll, Stream};
use http::header::CONTENT_LENGTH;
use std::fmt;
use std::ops::{Deref, DerefMut};
use bytes::{Bytes, BytesMut};
use futures::{Poll, Future, Stream};
use http::header::CONTENT_LENGTH;
use mime;
use serde_json;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json;
use error::{Error, JsonPayloadError, PayloadError};
use handler::{Responder, FromRequest};
use handler::{FromRequest, Responder};
use http::StatusCode;
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
@ -18,80 +18,15 @@ use httpresponse::HttpResponse;
/// Json helper
///
/// Json can be used for two different purpose. First is for json response generation
/// and second is for extracting typed information from request's payload.
pub struct Json<T>(pub T);
impl<T> Json<T> {
/// Deconstruct to an inner value
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> fmt::Debug for Json<T> where T: fmt::Debug {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Json: {:?}", self.0)
}
}
impl<T> fmt::Display for Json<T> where T: fmt::Display {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
/// The `Json` type allows you to respond with well-formed JSON data: simply
/// return a value of type Json<T> where T is the type of a structure
/// to serialize into *JSON*. The type `T` must implement the `Serialize`
/// trait from *serde*.
/// Json can be used for two different purpose. First is for json response
/// generation and second is for extracting typed information from request's
/// payload.
///
/// ```rust
/// # extern crate actix_web;
/// # #[macro_use] extern crate serde_derive;
/// # use actix_web::*;
/// #
/// #[derive(Serialize)]
/// struct MyObj {
/// name: String,
/// }
/// To extract typed information from request's body, the type `T` must
/// implement the `Deserialize` trait from *serde*.
///
/// fn index(req: HttpRequest) -> Result<Json<MyObj>> {
/// Ok(Json(MyObj{name: req.match_info().query("name")?}))
/// }
/// # fn main() {}
/// ```
impl<T: Serialize> Responder for Json<T> {
type Item = HttpResponse;
type Error = Error;
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
let body = serde_json::to_string(&self.0)?;
Ok(req.build_response(StatusCode::OK)
.content_type("application/json")
.body(body))
}
}
/// To extract typed information from request's body, the type `T` must implement the
/// `Deserialize` trait from *serde*.
///
/// [**JsonConfig**](dev/struct.JsonConfig.html) allows to configure extraction process.
/// [**JsonConfig**](dev/struct.JsonConfig.html) allows to configure extraction
/// process.
///
/// ## Example
///
@ -116,11 +51,88 @@ impl<T: Serialize> Responder for Json<T> {
/// |r| r.method(http::Method::POST).with(index)); // <- use `with` extractor
/// }
/// ```
///
/// The `Json` type allows you to respond with well-formed JSON data: simply
/// return a value of type Json<T> where T is the type of a structure
/// to serialize into *JSON*. The type `T` must implement the `Serialize`
/// trait from *serde*.
///
/// ```rust
/// # extern crate actix_web;
/// # #[macro_use] extern crate serde_derive;
/// # use actix_web::*;
/// #
/// #[derive(Serialize)]
/// struct MyObj {
/// name: String,
/// }
///
/// fn index(req: HttpRequest) -> Result<Json<MyObj>> {
/// Ok(Json(MyObj{name: req.match_info().query("name")?}))
/// }
/// # fn main() {}
/// ```
pub struct Json<T>(pub T);
impl<T> Json<T> {
/// Deconstruct to an inner value
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> fmt::Debug for Json<T>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Json: {:?}", self.0)
}
}
impl<T> fmt::Display for Json<T>
where
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl<T: Serialize> Responder for Json<T> {
type Item = HttpResponse;
type Error = Error;
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
let body = serde_json::to_string(&self.0)?;
Ok(req.build_response(StatusCode::OK)
.content_type("application/json")
.body(body))
}
}
impl<T, S> FromRequest<S> for Json<T>
where T: DeserializeOwned + 'static, S: 'static
where
T: DeserializeOwned + 'static,
S: 'static,
{
type Config = JsonConfig;
type Result = Box<Future<Item=Self, Error=Error>>;
type Result = Box<Future<Item = Self, Error = Error>>;
#[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -128,7 +140,8 @@ impl<T, S> FromRequest<S> for Json<T>
JsonBody::new(req.clone())
.limit(cfg.limit)
.from_err()
.map(Json))
.map(Json),
)
}
}
@ -163,7 +176,6 @@ pub struct JsonConfig {
}
impl JsonConfig {
/// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit;
@ -173,7 +185,7 @@ impl JsonConfig {
impl Default for JsonConfig {
fn default() -> Self {
JsonConfig{limit: 262_144}
JsonConfig { limit: 262_144 }
}
}
@ -208,17 +220,16 @@ impl Default for JsonConfig {
/// }
/// # fn main() {}
/// ```
pub struct JsonBody<T, U: DeserializeOwned>{
pub struct JsonBody<T, U: DeserializeOwned> {
limit: usize,
req: Option<T>,
fut: Option<Box<Future<Item=U, Error=JsonPayloadError>>>,
fut: Option<Box<Future<Item = U, Error = JsonPayloadError>>>,
}
impl<T, U: DeserializeOwned> JsonBody<T, U> {
/// Create `JsonBody` for request.
pub fn new(req: T) -> Self {
JsonBody{
JsonBody {
limit: 262_144,
req: Some(req),
fut: None,
@ -233,7 +244,8 @@ impl<T, U: DeserializeOwned> JsonBody<T, U> {
}
impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static
where
T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
{
type Item = U;
type Error = JsonPayloadError;
@ -259,7 +271,7 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
false
};
if !json {
return Err(JsonPayloadError::ContentType)
return Err(JsonPayloadError::ContentType);
}
let limit = self.limit;
@ -276,7 +288,10 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
self.fut = Some(Box::new(fut));
}
self.fut.as_mut().expect("JsonBody could not be used second time").poll()
self.fut
.as_mut()
.expect("JsonBody could not be used second time")
.poll()
}
}
@ -284,11 +299,11 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
mod tests {
use super::*;
use bytes::Bytes;
use http::header;
use futures::Async;
use http::header;
use with::{With, ExtractorConfig};
use handler::Handler;
use with::{ExtractorConfig, With};
impl PartialEq for JsonPayloadError {
fn eq(&self, other: &JsonPayloadError) -> bool {
@ -313,59 +328,100 @@ mod tests {
#[test]
fn test_json() {
let json = Json(MyObject{name: "test".to_owned()});
let json = Json(MyObject {
name: "test".to_owned(),
});
let resp = json.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "application/json");
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"application/json"
);
}
#[test]
fn test_json_body() {
let req = HttpRequest::default();
let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType);
assert_eq!(
json.poll().err().unwrap(),
JsonPayloadError::ContentType
);
let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"));
req.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"),
);
let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType);
assert_eq!(
json.poll().err().unwrap(),
JsonPayloadError::ContentType
);
let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"));
req.headers_mut().insert(header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"));
req.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
req.headers_mut().insert(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"),
);
let mut json = req.json::<MyObject>().limit(100);
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow);
let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"));
req.headers_mut().insert(header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"));
req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
req.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
req.headers_mut().insert(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
);
req.payload_mut()
.unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
let mut json = req.json::<MyObject>();
assert_eq!(json.poll().ok().unwrap(),
Async::Ready(MyObject{name: "test".to_owned()}));
assert_eq!(
json.poll().ok().unwrap(),
Async::Ready(MyObject {
name: "test".to_owned()
})
);
}
#[test]
fn test_with_json() {
let mut cfg = ExtractorConfig::<_, Json<MyObject>>::default();
cfg.limit(4096);
let mut handler = With::new(|data: Json<MyObject>| {data}, cfg);
let mut handler = With::new(|data: Json<MyObject>| data, cfg);
let req = HttpRequest::default();
let err = handler.handle(req).as_response().unwrap().error().is_some();
let err = handler
.handle(req)
.as_response()
.unwrap()
.error()
.is_some();
assert!(err);
let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"));
req.headers_mut().insert(header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"));
req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
let ok = handler.handle(req).as_response().unwrap().error().is_none();
req.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
req.headers_mut().insert(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
);
req.payload_mut()
.unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
let ok = handler
.handle(req)
.as_response()
.unwrap()
.error()
.is_none();
assert!(ok)
}
}

View File

@ -64,17 +64,17 @@
#![cfg_attr(actix_nightly, feature(
specialization, // for impl ErrorResponse for std::error::Error
))]
#![cfg_attr(feature = "cargo-clippy", allow(
decimal_literal_representation,suspicious_arithmetic_impl,))]
#![cfg_attr(feature = "cargo-clippy",
allow(decimal_literal_representation, suspicious_arithmetic_impl))]
#[macro_use]
extern crate log;
extern crate time;
extern crate base64;
extern crate bytes;
extern crate byteorder;
extern crate sha1;
extern crate bytes;
extern crate regex;
extern crate sha1;
extern crate time;
#[macro_use]
extern crate bitflags;
#[macro_use]
@ -83,46 +83,49 @@ extern crate failure;
extern crate lazy_static;
#[macro_use]
extern crate futures;
extern crate futures_cpupool;
extern crate tokio_io;
extern crate tokio_core;
extern crate mio;
extern crate net2;
extern crate cookie;
extern crate futures_cpupool;
extern crate http as modhttp;
extern crate httparse;
extern crate http_range;
extern crate httparse;
extern crate language_tags;
extern crate libc;
extern crate mime;
extern crate mime_guess;
extern crate language_tags;
extern crate mio;
extern crate net2;
extern crate rand;
extern crate tokio_core;
extern crate tokio_io;
extern crate url;
extern crate libc;
#[macro_use] extern crate serde;
extern crate serde_json;
extern crate serde_urlencoded;
extern crate flate2;
#[cfg(feature="brotli")]
#[macro_use]
extern crate serde;
#[cfg(feature = "brotli")]
extern crate brotli2;
extern crate encoding;
extern crate percent_encoding;
extern crate smallvec;
extern crate num_cpus;
extern crate flate2;
extern crate h2 as http2;
extern crate num_cpus;
extern crate percent_encoding;
extern crate serde_json;
extern crate serde_urlencoded;
extern crate smallvec;
extern crate trust_dns_resolver;
#[macro_use] extern crate actix;
#[macro_use]
extern crate actix;
#[cfg(test)]
#[macro_use] extern crate serde_derive;
#[macro_use]
extern crate serde_derive;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
extern crate native_tls;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
extern crate tokio_tls;
#[cfg(feature="openssl")]
#[cfg(feature = "openssl")]
extern crate openssl;
#[cfg(feature="openssl")]
#[cfg(feature = "openssl")]
extern crate tokio_openssl;
mod application;
@ -138,33 +141,33 @@ mod httprequest;
mod httpresponse;
mod info;
mod json;
mod route;
mod router;
mod resource;
mod param;
mod payload;
mod pipeline;
mod resource;
mod route;
mod router;
mod with;
pub mod client;
pub mod fs;
pub mod ws;
pub mod error;
pub mod multipart;
pub mod fs;
pub mod middleware;
pub mod multipart;
pub mod pred;
pub mod test;
pub mod server;
pub use extractor::{Path, Form, Query};
pub use error::{Error, Result, ResponseError};
pub use body::{Body, Binary};
pub use json::Json;
pub mod test;
pub mod ws;
pub use application::App;
pub use body::{Binary, Body};
pub use context::HttpContext;
pub use error::{Error, ResponseError, Result};
pub use extractor::{Form, Path, Query};
pub use handler::{AsyncResponder, Either, FromRequest, FutureResponse, Responder, State};
pub use httpmessage::HttpMessage;
pub use httprequest::HttpRequest;
pub use httpresponse::HttpResponse;
pub use handler::{Either, Responder, AsyncResponder, FromRequest, FutureResponse, State};
pub use context::HttpContext;
pub use json::Json;
#[doc(hidden)]
pub mod httpcodes;
@ -173,39 +176,39 @@ pub mod httpcodes;
#[allow(deprecated)]
pub use application::Application;
#[cfg(feature="openssl")]
#[cfg(feature = "openssl")]
pub(crate) const HAS_OPENSSL: bool = true;
#[cfg(not(feature="openssl"))]
#[cfg(not(feature = "openssl"))]
pub(crate) const HAS_OPENSSL: bool = false;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
pub(crate) const HAS_TLS: bool = true;
#[cfg(not(feature="tls"))]
#[cfg(not(feature = "tls"))]
pub(crate) const HAS_TLS: bool = false;
pub mod dev {
//! The `actix-web` prelude for library developers
//!
//! The purpose of this module is to alleviate imports of many common actix traits
//! by adding a glob import to the top of actix heavy modules:
//!
//! ```
//! # #![allow(unused_imports)]
//! use actix_web::dev::*;
//! ```
//! The `actix-web` prelude for library developers
//!
//! The purpose of this module is to alleviate imports of many common actix
//! traits by adding a glob import to the top of actix heavy modules:
//!
//! ```
//! # #![allow(unused_imports)]
//! use actix_web::dev::*;
//! ```
pub use body::BodyStream;
pub use context::Drain;
pub use json::{JsonBody, JsonConfig};
pub use info::ConnectionInfo;
pub use handler::{Handler, Reply};
pub use extractor::{FormConfig, PayloadConfig};
pub use route::Route;
pub use router::{Router, Resource, ResourceType};
pub use resource::ResourceHandler;
pub use param::{FromParam, Params};
pub use httpmessage::{UrlEncoded, MessageBody};
pub use handler::{Handler, Reply};
pub use httpmessage::{MessageBody, UrlEncoded};
pub use httpresponse::HttpResponseBuilder;
pub use info::ConnectionInfo;
pub use json::{JsonBody, JsonConfig};
pub use param::{FromParam, Params};
pub use resource::ResourceHandler;
pub use route::Route;
pub use router::{Resource, ResourceType, Router};
}
pub mod http {
@ -215,15 +218,15 @@ pub mod http {
pub use modhttp::{Method, StatusCode, Version};
#[doc(hidden)]
pub use modhttp::{uri, Uri, Error, Extensions, HeaderMap, HttpTryFrom};
pub use modhttp::{uri, Error, Extensions, HeaderMap, HttpTryFrom, Uri};
pub use http_range::HttpRange;
pub use cookie::{Cookie, CookieBuilder};
pub use http_range::HttpRange;
pub use helpers::NormalizePath;
pub mod header {
pub use ::header::*;
pub use header::*;
}
pub use header::ContentEncoding;
pub use httpresponse::ConnectionType;

View File

@ -7,7 +7,8 @@
//!
//! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
//! 2. Use any of the builder methods to set fields in the backend.
//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend.
//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
//! constructed backend.
//!
//! Cors middleware could be used as parameter for `App::middleware()` or
//! `ResourceHandler::middleware()` methods. But you have to use
@ -40,65 +41,69 @@
//! .register());
//! }
//! ```
//! In this example custom *CORS* middleware get registered for "/index.html" endpoint.
//! In this example custom *CORS* middleware get registered for "/index.html"
//! endpoint.
//!
//! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet;
use std::iter::FromIterator;
use std::rc::Rc;
use http::{self, Method, HttpTryFrom, Uri, StatusCode};
use http::header::{self, HeaderName, HeaderValue};
use http::{self, HttpTryFrom, Method, StatusCode, Uri};
use application::App;
use error::{Result, ResponseError};
use resource::ResourceHandler;
use error::{ResponseError, Result};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Response, Started};
use resource::ResourceHandler;
/// A set of errors that can occur during processing CORS
#[derive(Debug, Fail)]
pub enum CorsError {
/// The HTTP request header `Origin` is required but was not provided
#[fail(display="The HTTP request header `Origin` is required but was not provided")]
#[fail(display = "The HTTP request header `Origin` is required but was not provided")]
MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly.
#[fail(display="The HTTP request header `Origin` could not be parsed correctly.")]
#[fail(display = "The HTTP request header `Origin` could not be parsed correctly.")]
BadOrigin,
/// The request header `Access-Control-Request-Method` is required but is missing
#[fail(display="The request header `Access-Control-Request-Method` is required but is missing")]
/// The request header `Access-Control-Request-Method` is required but is
/// missing
#[fail(display = "The request header `Access-Control-Request-Method` is required but is missing")]
MissingRequestMethod,
/// The request header `Access-Control-Request-Method` has an invalid value
#[fail(display="The request header `Access-Control-Request-Method` has an invalid value")]
#[fail(display = "The request header `Access-Control-Request-Method` has an invalid value")]
BadRequestMethod,
/// The request header `Access-Control-Request-Headers` has an invalid value
#[fail(display="The request header `Access-Control-Request-Headers` has an invalid value")]
/// The request header `Access-Control-Request-Headers` has an invalid
/// value
#[fail(display = "The request header `Access-Control-Request-Headers` has an invalid value")]
BadRequestHeaders,
/// The request header `Access-Control-Request-Headers` is required but is missing.
#[fail(display="The request header `Access-Control-Request-Headers` is required but is
/// The request header `Access-Control-Request-Headers` is required but is
/// missing.
#[fail(display = "The request header `Access-Control-Request-Headers` is required but is
missing")]
MissingRequestHeaders,
/// Origin is not allowed to make this request
#[fail(display="Origin is not allowed to make this request")]
#[fail(display = "Origin is not allowed to make this request")]
OriginNotAllowed,
/// Requested method is not allowed
#[fail(display="Requested method is not allowed")]
#[fail(display = "Requested method is not allowed")]
MethodNotAllowed,
/// One or more headers requested are not allowed
#[fail(display="One or more headers requested are not allowed")]
#[fail(display = "One or more headers requested are not allowed")]
HeadersNotAllowed,
}
impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse {
HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self))
}
}
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed).
/// An enum signifying that some of type T is allowed, or `All` (everything is
/// allowed).
///
/// `Default` is implemented for this enum and is `All`.
#[derive(Clone, Debug, Eq, PartialEq)]
@ -166,9 +171,16 @@ impl Default for Cors {
origins: AllOrSome::default(),
origins_str: None,
methods: HashSet::from_iter(
vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT,
Method::PATCH, Method::DELETE].into_iter()),
vec![
Method::GET,
Method::HEAD,
Method::POST,
Method::OPTIONS,
Method::PUT,
Method::PATCH,
Method::DELETE,
].into_iter(),
),
headers: AllOrSome::All,
expose_hdrs: None,
max_age: None,
@ -177,7 +189,9 @@ impl Default for Cors {
supports_credentials: false,
vary_header: true,
};
Cors{inner: Rc::new(inner)}
Cors {
inner: Rc::new(inner),
}
}
}
@ -247,11 +261,13 @@ impl Cors {
/// This method register cors middleware with resource and
/// adds route for *OPTIONS* preflight requests.
///
/// It is possible to register *Cors* middleware with `ResourceHandler::middleware()`
/// method, but in that case *Cors* middleware wont be able to handle *OPTIONS*
/// requests.
/// It is possible to register *Cors* middleware with
/// `ResourceHandler::middleware()` method, but in that case *Cors*
/// middleware wont be able to handle *OPTIONS* requests.
pub fn register<S: 'static>(self, resource: &mut ResourceHandler<S>) {
resource.method(Method::OPTIONS).h(|_| HttpResponse::Ok());
resource
.method(Method::OPTIONS)
.h(|_| HttpResponse::Ok());
resource.middleware(self);
}
@ -260,28 +276,32 @@ impl Cors {
if let Ok(origin) = hdr.to_str() {
return match self.inner.origins {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => {
allowed_origins
.get(origin)
.and_then(|_| Some(()))
.ok_or_else(|| CorsError::OriginNotAllowed)
}
AllOrSome::Some(ref allowed_origins) => allowed_origins
.get(origin)
.and_then(|_| Some(()))
.ok_or_else(|| CorsError::OriginNotAllowed),
};
}
Err(CorsError::BadOrigin)
} else {
return match self.inner.origins {
AllOrSome::All => Ok(()),
_ => Err(CorsError::MissingOrigin)
}
_ => Err(CorsError::MissingOrigin),
};
}
}
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
fn validate_allowed_method<S>(
&self, req: &mut HttpRequest<S>
) -> Result<(), CorsError> {
if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
{
if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) {
return self.inner.methods.get(&method)
return self.inner
.methods
.get(&method)
.and_then(|_| Some(()))
.ok_or_else(|| CorsError::MethodNotAllowed);
}
@ -292,24 +312,28 @@ impl Cors {
}
}
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> {
fn validate_allowed_headers<S>(
&self, req: &mut HttpRequest<S>
) -> Result<(), CorsError> {
match self.inner.headers {
AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
{
if let Ok(headers) = hdr.to_str() {
let mut hdrs = HashSet::new();
for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) {
Ok(hdr) => hdrs.insert(hdr),
Err(_) => return Err(CorsError::BadRequestHeaders)
Err(_) => return Err(CorsError::BadRequestHeaders),
};
}
if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) {
return Err(CorsError::HeadersNotAllowed)
return Err(CorsError::HeadersNotAllowed);
}
return Ok(())
return Ok(());
}
Err(CorsError::BadRequestHeaders)
} else {
@ -321,7 +345,6 @@ impl Cors {
}
impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
if self.inner.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?;
@ -330,9 +353,17 @@ impl<S> Middleware<S> for Cors {
// allowed headers
let headers = if let Some(headers) = self.inner.headers.as_ref() {
Some(HeaderValue::try_from(&headers.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap())
} else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
Some(
HeaderValue::try_from(
&headers
.iter()
.fold(String::new(), |s, v| s + "," + v.as_str())
.as_str()[1..],
).unwrap(),
)
} else if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
{
Some(hdr.clone())
} else {
None
@ -342,31 +373,44 @@ impl<S> Middleware<S> for Cors {
HttpResponse::Ok()
.if_some(self.inner.max_age.as_ref(), |max_age, resp| {
let _ = resp.header(
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());})
header::ACCESS_CONTROL_MAX_AGE,
format!("{}", max_age).as_str(),
);
})
.if_some(headers, |headers, resp| {
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); })
let _ =
resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
})
.if_true(self.inner.origins.is_all(), |resp| {
if self.inner.send_wildcard {
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
} else {
let origin = req.headers().get(header::ORIGIN).unwrap();
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
}
})
.if_true(self.inner.origins.is_some(), |resp| {
resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone());
self.inner.origins_str.as_ref().unwrap().clone(),
);
})
.if_true(self.inner.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
})
.header(
header::ACCESS_CONTROL_ALLOW_METHODS,
&self.inner.methods.iter().fold(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..])
.finish()))
&self.inner
.methods
.iter()
.fold(String::new(), |s, v| s + "," + v.as_str())
.as_str()[1..],
)
.finish(),
))
} else {
self.validate_origin(req)?;
@ -374,32 +418,40 @@ impl<S> Middleware<S> for Cors {
}
}
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
fn response(
&self, req: &mut HttpRequest<S>, mut resp: HttpResponse
) -> Result<Response> {
match self.inner.origins {
AllOrSome::All => {
if self.inner.send_wildcard {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("*"),
);
} else if let Some(origin) = req.headers().get(header::ORIGIN) {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
}
}
AllOrSome::Some(_) => {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone());
self.inner.origins_str.as_ref().unwrap().clone(),
);
}
}
if let Some(ref expose) = self.inner.expose_hdrs {
resp.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap());
HeaderValue::try_from(expose.as_str()).unwrap(),
);
}
if self.inner.supports_credentials {
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if self.inner.vary_header {
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
@ -416,13 +468,15 @@ impl<S> Middleware<S> for Cors {
}
}
/// Structure that follows the builder pattern for building `Cors` middleware structs.
/// Structure that follows the builder pattern for building `Cors` middleware
/// structs.
///
/// To construct a cors:
///
/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
/// 2. Use any of the builder methods to set fields in the backend.
/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend.
/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
/// constructed backend.
///
/// # Example
///
@ -442,7 +496,7 @@ impl<S> Middleware<S> for Cors {
/// .finish();
/// # }
/// ```
pub struct CorsBuilder<S=()> {
pub struct CorsBuilder<S = ()> {
cors: Option<Inner>,
methods: bool,
error: Option<http::Error>,
@ -451,26 +505,26 @@ pub struct CorsBuilder<S=()> {
app: Option<App<S>>,
}
fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<http::Error>)
-> Option<&'a mut Inner>
{
fn cors<'a>(
parts: &'a mut Option<Inner>, err: &Option<http::Error>
) -> Option<&'a mut Inner> {
if err.is_some() {
return None
return None;
}
parts.as_mut()
}
impl<S: 'static> CorsBuilder<S> {
/// Add an origin that are allowed to make requests.
/// Will be verified against the `Origin` request header.
///
/// When `All` is set, and `send_wildcard` is set, "*" will be sent in
/// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request
/// header will be echoed back in the `Access-Control-Allow-Origin` response header.
/// the `Access-Control-Allow-Origin` response header. Otherwise, the
/// client's `Origin` request header will be echoed back in the
/// `Access-Control-Allow-Origin` response header.
///
/// When `Some` is set, the client's `Origin` request header will be checked in a
/// case-sensitive manner.
/// When `Some` is set, the client's `Origin` request header will be
/// checked in a case-sensitive manner.
///
/// This is the `list of origins` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
@ -497,15 +551,17 @@ impl<S: 'static> CorsBuilder<S> {
self
}
/// Set a list of methods which the allowed origins are allowed to access for
/// requests.
/// Set a list of methods which the allowed origins are allowed to access
/// for requests.
///
/// This is the `list of methods` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
pub fn allowed_methods<U, M>(&mut self, methods: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=M>, Method: HttpTryFrom<M>
where
U: IntoIterator<Item = M>,
Method: HttpTryFrom<M>,
{
self.methods = true;
if let Some(cors) = cors(&mut self.cors, &self.error) {
@ -513,20 +569,21 @@ impl<S: 'static> CorsBuilder<S> {
match Method::try_from(m) {
Ok(method) => {
cors.methods.insert(method);
},
}
Err(e) => {
self.error = Some(e.into());
break
break;
}
}
};
}
}
self
}
/// Set an allowed header
pub fn allowed_header<H>(&mut self, header: H) -> &mut CorsBuilder<S>
where HeaderName: HttpTryFrom<H>
where
HeaderName: HttpTryFrom<H>,
{
if let Some(cors) = cors(&mut self.cors, &self.error) {
match HeaderName::try_from(header) {
@ -547,15 +604,18 @@ impl<S: 'static> CorsBuilder<S> {
/// Set a list of header field names which can be used when
/// this resource is accessed by allowed origins.
///
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
/// will be echoed back in the `Access-Control-Allow-Headers` header.
/// If `All` is set, whatever is requested by the client in
/// `Access-Control-Request-Headers` will be echoed back in the
/// `Access-Control-Allow-Headers` header.
///
/// This is the `list of headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// Defaults to `All`.
pub fn allowed_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
where
U: IntoIterator<Item = H>,
HeaderName: HttpTryFrom<H>,
{
if let Some(cors) = cors(&mut self.cors, &self.error) {
for h in headers {
@ -570,32 +630,35 @@ impl<S: 'static> CorsBuilder<S> {
}
Err(e) => {
self.error = Some(e.into());
break
break;
}
}
};
}
}
self
}
/// Set a list of headers which are safe to expose to the API of a CORS API specification.
/// This corresponds to the `Access-Control-Expose-Headers` response header.
/// Set a list of headers which are safe to expose to the API of a CORS API
/// specification. This corresponds to the
/// `Access-Control-Expose-Headers` response header.
///
/// This is the `list of exposed headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// This defaults to an empty set.
pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H>
where
U: IntoIterator<Item = H>,
HeaderName: HttpTryFrom<H>,
{
for h in headers {
match HeaderName::try_from(h) {
Ok(method) => {
self.expose_hdrs.insert(method);
},
}
Err(e) => {
self.error = Some(e.into());
break
break;
}
}
}
@ -615,16 +678,17 @@ impl<S: 'static> CorsBuilder<S> {
/// Set a wildcard origins
///
/// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
/// `Access-Control-Allow-Origin` response header is sent, rather than the requests
/// `Origin` header.
/// If send wildcard is set and the `allowed_origins` parameter is `All`, a
/// wildcard `Access-Control-Allow-Origin` response header is sent,
/// rather than the requests `Origin` header.
///
/// This is the `supports credentials flag` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
///
/// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and
/// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result
/// in an `Error::CredentialsWithWildcardOrigin` error during actix launch or runtime.
/// This **CANNOT** be used in conjunction with `allowed_origins` set to
/// `All` and `allow_credentials` set to `true`. Depending on the mode
/// of usage, this will either result in an `Error::
/// CredentialsWithWildcardOrigin` error during actix launch or runtime.
///
/// Defaults to `false`.
pub fn send_wildcard(&mut self) -> &mut CorsBuilder<S> {
@ -636,11 +700,12 @@ impl<S: 'static> CorsBuilder<S> {
/// Allows users to make authenticated requests
///
/// If true, injects the `Access-Control-Allow-Credentials` header in responses.
/// This allows cookies and credentials to be submitted across domains.
/// If true, injects the `Access-Control-Allow-Credentials` header in
/// responses. This allows cookies and credentials to be submitted
/// across domains.
///
/// This option cannot be used in conjunction with an `allowed_origin` set to `All`
/// and `send_wildcards` set to `true`.
/// This option cannot be used in conjunction with an `allowed_origin` set
/// to `All` and `send_wildcards` set to `true`.
///
/// Defaults to `false`.
///
@ -713,7 +778,8 @@ impl<S: 'static> CorsBuilder<S> {
/// }
/// ```
pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut CorsBuilder<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static
where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{
// add resource handler
let mut handler = ResourceHandler::default();
@ -725,9 +791,15 @@ impl<S: 'static> CorsBuilder<S> {
fn construct(&mut self) -> Cors {
if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD,
Method::POST, Method::OPTIONS, Method::PUT,
Method::PATCH, Method::DELETE]);
self.allowed_methods(vec![
Method::GET,
Method::HEAD,
Method::POST,
Method::OPTIONS,
Method::PUT,
Method::PATCH,
Method::DELETE,
]);
}
if let Some(e) = self.error.take() {
@ -741,16 +813,23 @@ impl<S: 'static> CorsBuilder<S> {
}
if let AllOrSome::Some(ref origins) = cors.origins {
let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v));
let s = origins
.iter()
.fold(String::new(), |s, v| s + &format!("{}", v));
cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap());
}
if !self.expose_hdrs.is_empty() {
cors.expose_hdrs = Some(
self.expose_hdrs.iter().fold(
String::new(), |s, v| s + v.as_str())[1..].to_owned());
self.expose_hdrs
.iter()
.fold(String::new(), |s, v| s + v.as_str())[1..]
.to_owned(),
);
}
Cors {
inner: Rc::new(cors),
}
Cors{inner: Rc::new(cors)}
}
/// Finishes building and returns the built `Cors` instance.
@ -758,13 +837,16 @@ impl<S: 'static> CorsBuilder<S> {
/// This method panics in case of any configuration error.
pub fn finish(&mut self) -> Cors {
if !self.resources.is_empty() {
panic!("CorsBuilder::resource() was used,
to construct CORS `.register(app)` method should be used");
panic!(
"CorsBuilder::resource() was used,
to construct CORS `.register(app)` method should be used"
);
}
self.construct()
}
/// Finishes building Cors middleware and register middleware for application
/// Finishes building Cors middleware and register middleware for
/// application
///
/// This method panics in case of any configuration error or if non of
/// resources are registered.
@ -774,8 +856,9 @@ impl<S: 'static> CorsBuilder<S> {
}
let cors = self.construct();
let mut app = self.app.take().expect(
"CorsBuilder has to be constructed with Cors::for_app(app)");
let mut app = self.app
.take()
.expect("CorsBuilder has to be constructed with Cors::for_app(app)");
// register resources
for (path, mut resource) in self.resources.drain(..) {
@ -787,7 +870,6 @@ impl<S: 'static> CorsBuilder<S> {
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -845,8 +927,8 @@ mod tests {
#[test]
fn validate_origin_allows_all_origins() {
let cors = Cors::default();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com").finish();
let mut req =
TestRequest::with_header("Origin", "https://www.example.com").finish();
assert!(cors.start(&mut req).ok().unwrap().is_done())
}
@ -861,8 +943,7 @@ mod tests {
.allowed_header(header::CONTENT_TYPE)
.finish();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.finish();
@ -877,23 +958,35 @@ mod tests {
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
.header(
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS)
.finish();
let resp = cors.start(&mut req).unwrap().response();
assert_eq!(
&b"*"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"3600"[..],
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes());
resp.headers()
.get(header::ACCESS_CONTROL_MAX_AGE)
.unwrap()
.as_bytes()
);
//assert_eq!(
// &b"authorization,accept,content-type"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes());
//assert_eq!(
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().
// as_bytes()); assert_eq!(
// &b"POST,GET,OPTIONS"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes());
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().
// as_bytes());
Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
assert!(cors.start(&mut req).unwrap().is_done());
@ -903,7 +996,8 @@ mod tests {
#[should_panic(expected = "MissingOrigin")]
fn test_validate_missing_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish();
.allowed_origin("https://www.example.com")
.finish();
let mut req = HttpRequest::default();
cors.start(&mut req).unwrap();
@ -913,7 +1007,8 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish();
.allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET)
@ -924,7 +1019,8 @@ mod tests {
#[test]
fn test_validate_origin() {
let cors = Cors::build()
.allowed_origin("https://www.example.com").finish();
.allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET)
@ -940,16 +1036,23 @@ mod tests {
let mut req = TestRequest::default().method(Method::GET).finish();
let resp: HttpResponse = HttpResponse::Ok().into();
let resp = cors.response(&mut req, resp).unwrap().response();
assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
assert!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()
);
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.finish();
let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!(
&b"https://www.example.com"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
}
#[test]
@ -963,8 +1066,7 @@ mod tests {
.allowed_header(header::CONTENT_TYPE)
.finish();
let mut req = TestRequest::with_header(
"Origin", "https://www.example.com")
let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.finish();
@ -972,10 +1074,15 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!(
&b"*"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes());
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let resp: HttpResponse = HttpResponse::Ok()
.header(header::VARY, "Accept")
@ -983,7 +1090,8 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes());
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let cors = Cors::build()
.disable_vary_header()
@ -993,25 +1101,33 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!(
&b"https://www.example.com"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes());
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
}
#[test]
fn cors_resource() {
let mut srv = test::TestServer::with_factory(
|| App::new()
.configure(
|app| Cors::for_app(app)
.allowed_origin("https://www.example.com")
.resource("/test", |r| r.f(|_| HttpResponse::Ok()))
.register()));
let mut srv = test::TestServer::with_factory(|| {
App::new().configure(|app| {
Cors::for_app(app)
.allowed_origin("https://www.example.com")
.resource("/test", |r| r.f(|_| HttpResponse::Ok()))
.register()
})
});
let request = srv.get().uri(srv.url("/test")).finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let request = srv.get().uri(srv.url("/test"))
.header("ORIGIN", "https://www.example.com").finish().unwrap();
let request = srv.get()
.uri(srv.url("/test"))
.header("ORIGIN", "https://www.example.com")
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::OK);
}

View File

@ -48,8 +48,8 @@ use std::borrow::Cow;
use std::collections::HashSet;
use bytes::Bytes;
use error::{Result, ResponseError};
use http::{HeaderMap, HttpTryFrom, Uri, header};
use error::{ResponseError, Result};
use http::{header, HeaderMap, HttpTryFrom, Uri};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
@ -59,13 +59,13 @@ use middleware::{Middleware, Started};
#[derive(Debug, Fail)]
pub enum CsrfError {
/// The HTTP request header `Origin` was required but not provided.
#[fail(display="Origin header required")]
#[fail(display = "Origin header required")]
MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly.
#[fail(display="Could not parse Origin header")]
#[fail(display = "Could not parse Origin header")]
BadOrigin,
/// The cross-site request was denied.
#[fail(display="Cross-site request denied")]
#[fail(display = "Cross-site request denied")]
CsrDenied,
}
@ -80,15 +80,14 @@ fn uri_origin(uri: &Uri) -> Option<String> {
(Some(scheme), Some(host), Some(port)) => {
Some(format!("{}://{}:{}", scheme, host, port))
}
(Some(scheme), Some(host), None) => {
Some(format!("{}://{}", scheme, host))
}
_ => None
(Some(scheme), Some(host), None) => Some(format!("{}://{}", scheme, host)),
_ => None,
}
}
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
headers.get(header::ORIGIN)
headers
.get(header::ORIGIN)
.map(|origin| {
origin
.to_str()
@ -96,15 +95,14 @@ fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
.map(|o| o.into())
})
.or_else(|| {
headers.get(header::REFERER)
.map(|referer| {
Uri::try_from(Bytes::from(referer.as_bytes()))
.ok()
.as_ref()
.and_then(uri_origin)
.ok_or(CsrfError::BadOrigin)
.map(|o| o.into())
})
headers.get(header::REFERER).map(|referer| {
Uri::try_from(Bytes::from(referer.as_bytes()))
.ok()
.as_ref()
.and_then(uri_origin)
.ok_or(CsrfError::BadOrigin)
.map(|o| o.into())
})
})
}
@ -194,7 +192,8 @@ impl CsrfFilter {
let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) {
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with"))
{
Ok(())
} else if let Some(header) = origin(req.headers()) {
match header {
@ -225,8 +224,7 @@ mod tests {
#[test]
fn test_safe() {
let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com");
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::HEAD)
@ -237,8 +235,7 @@ mod tests {
#[test]
fn test_csrf() {
let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com");
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::POST)
@ -249,11 +246,12 @@ mod tests {
#[test]
fn test_referer() {
let csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com");
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param")
.method(Method::POST)
let mut req = TestRequest::with_header(
"Referer",
"https://www.example.com/some/path?query=param",
).method(Method::POST)
.finish();
assert!(csrf.start(&mut req).is_ok());
@ -261,8 +259,7 @@ mod tests {
#[test]
fn test_upgrade() {
let strict_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com");
let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let lax_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com")

View File

@ -1,11 +1,11 @@
//! Default response headers
use http::{HeaderMap, HttpTryFrom};
use http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use http::{HeaderMap, HttpTryFrom};
use error::Result;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Response, Middleware};
use middleware::{Middleware, Response};
/// `Middleware` for setting default response headers.
///
@ -27,14 +27,17 @@ use middleware::{Response, Middleware};
/// .finish();
/// }
/// ```
pub struct DefaultHeaders{
pub struct DefaultHeaders {
ct: bool,
headers: HeaderMap,
}
impl Default for DefaultHeaders {
fn default() -> Self {
DefaultHeaders{ct: false, headers: HeaderMap::new()}
DefaultHeaders {
ct: false,
headers: HeaderMap::new(),
}
}
}
@ -48,15 +51,16 @@ impl DefaultHeaders {
#[inline]
#[cfg_attr(feature = "cargo-clippy", allow(match_wild_err_arm))]
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>,
HeaderValue: HttpTryFrom<V>
where
HeaderName: HttpTryFrom<K>,
HeaderValue: HttpTryFrom<V>,
{
match HeaderName::try_from(key) {
Ok(key) => {
match HeaderValue::try_from(value) {
Ok(value) => { self.headers.append(key, value); }
Err(_) => panic!("Can not create header value"),
Ok(key) => match HeaderValue::try_from(value) {
Ok(value) => {
self.headers.append(key, value);
}
Err(_) => panic!("Can not create header value"),
},
Err(_) => panic!("Can not create header name"),
}
@ -71,8 +75,9 @@ impl DefaultHeaders {
}
impl<S> Middleware<S> for DefaultHeaders {
fn response(&self, _: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> {
fn response(
&self, _: &mut HttpRequest<S>, mut resp: HttpResponse
) -> Result<Response> {
for (key, value) in self.headers.iter() {
if !resp.headers().contains_key(key) {
resp.headers_mut().insert(key, value.clone());
@ -81,7 +86,9 @@ impl<S> Middleware<S> for DefaultHeaders {
// default content-type
if self.ct && !resp.headers().contains_key(CONTENT_TYPE) {
resp.headers_mut().insert(
CONTENT_TYPE, HeaderValue::from_static("application/octet-stream"));
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
}
Ok(Response::Done(resp))
}
@ -94,8 +101,7 @@ mod tests {
#[test]
fn test_default_headers() {
let mw = DefaultHeaders::new()
.header(CONTENT_TYPE, "0001");
let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001");
let mut req = HttpRequest::default();
@ -106,7 +112,9 @@ mod tests {
};
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let resp = HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish();
let resp = HttpResponse::Ok()
.header(CONTENT_TYPE, "0002")
.finish();
let resp = match mw.response(&mut req, resp) {
Ok(Response::Done(resp)) => resp,
_ => panic!(),

View File

@ -6,14 +6,13 @@ use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Response};
type ErrorHandler<S> = Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response>;
/// `Middleware` for allowing custom handlers for responses.
///
/// You can use `ErrorHandlers::handler()` method to register a custom error handler
/// for specific status code. You can modify existing response or create completly new
/// one.
/// You can use `ErrorHandlers::handler()` method to register a custom error
/// handler for specific status code. You can modify existing response or
/// create completly new one.
///
/// ## Example
///
@ -53,7 +52,6 @@ impl<S> Default for ErrorHandlers<S> {
}
impl<S> ErrorHandlers<S> {
/// Construct new `ErrorHandlers` instance
pub fn new() -> Self {
ErrorHandlers::default()
@ -61,7 +59,8 @@ impl<S> ErrorHandlers<S> {
/// Register error handler for specified status code
pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
where F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static
where
F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static,
{
self.handlers.insert(status, Box::new(handler));
self
@ -69,8 +68,9 @@ impl<S> ErrorHandlers<S> {
}
impl<S: 'static> Middleware<S> for ErrorHandlers<S> {
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
if let Some(handler) = self.handlers.get(&resp.status()) {
handler(req, resp)
} else {
@ -90,11 +90,11 @@ mod tests {
builder.header(CONTENT_TYPE, "0001");
Ok(Response::Done(builder.into()))
}
#[test]
fn test_handler() {
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mw =
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mut req = HttpRequest::default();
let resp = HttpResponse::InternalServerError().finish();

View File

@ -4,14 +4,14 @@ use std::fmt;
use std::fmt::{Display, Formatter};
use libc;
use time;
use regex::Regex;
use time;
use error::Result;
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Started, Finished};
use middleware::{Finished, Middleware, Started};
/// `Middleware` for logging request and response info to the terminal.
/// `Logger` middleware uses standard log crate to log information. You should
@ -21,7 +21,8 @@ use middleware::{Middleware, Started, Finished};
/// ## Usage
///
/// Create `Logger` middleware with the specified `format`.
/// Default `Logger` could be created with `default` method, it uses the default format:
/// Default `Logger` could be created with `default` method, it uses the
/// default format:
///
/// ```ignore
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
@ -59,7 +60,8 @@ use middleware::{Middleware, Started, Finished};
///
/// `%b` Size of response in bytes, including HTTP headers
///
/// `%T` Time taken to serve the request, in seconds with floating fraction in .06f format
/// `%T` Time taken to serve the request, in seconds with floating fraction in
/// .06f format
///
/// `%D` Time taken to serve the request, in milliseconds
///
@ -76,7 +78,9 @@ pub struct Logger {
impl Logger {
/// Create `Logger` middleware with the specified `format`.
pub fn new(format: &str) -> Logger {
Logger { format: Format::new(format) }
Logger {
format: Format::new(format),
}
}
}
@ -87,14 +91,15 @@ impl Default for Logger {
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
/// ```
fn default() -> Logger {
Logger { format: Format::default() }
Logger {
format: Format::default(),
}
}
}
struct StartTime(time::Tm);
impl Logger {
fn log<S>(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) {
let entry_time = req.extensions().get::<StartTime>().unwrap().0;
@ -109,7 +114,6 @@ impl Logger {
}
impl<S> Middleware<S> for Logger {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
req.extensions().insert(StartTime(time::now()));
Ok(Started::Done)
@ -153,29 +157,26 @@ impl Format {
idx = m.end();
if let Some(key) = cap.get(2) {
results.push(
match cap.get(3).unwrap().as_str() {
"i" => FormatText::RequestHeader(key.as_str().to_owned()),
"o" => FormatText::ResponseHeader(key.as_str().to_owned()),
"e" => FormatText::EnvironHeader(key.as_str().to_owned()),
_ => unreachable!(),
})
results.push(match cap.get(3).unwrap().as_str() {
"i" => FormatText::RequestHeader(key.as_str().to_owned()),
"o" => FormatText::ResponseHeader(key.as_str().to_owned()),
"e" => FormatText::EnvironHeader(key.as_str().to_owned()),
_ => unreachable!(),
})
} else {
let m = cap.get(1).unwrap();
results.push(
match m.as_str() {
"%" => FormatText::Percent,
"a" => FormatText::RemoteAddr,
"t" => FormatText::RequestTime,
"P" => FormatText::Pid,
"r" => FormatText::RequestLine,
"s" => FormatText::ResponseStatus,
"b" => FormatText::ResponseSize,
"T" => FormatText::Time,
"D" => FormatText::TimeMillis,
_ => FormatText::Str(m.as_str().to_owned()),
}
);
results.push(match m.as_str() {
"%" => FormatText::Percent,
"a" => FormatText::RemoteAddr,
"t" => FormatText::RequestTime,
"P" => FormatText::Pid,
"r" => FormatText::RequestLine,
"s" => FormatText::ResponseStatus,
"b" => FormatText::ResponseSize,
"T" => FormatText::Time,
"D" => FormatText::TimeMillis,
_ => FormatText::Str(m.as_str().to_owned()),
});
}
}
if idx != s.len() {
@ -207,12 +208,10 @@ pub enum FormatText {
}
impl FormatText {
fn render<S>(&self, fmt: &mut Formatter,
req: &HttpRequest<S>,
resp: &HttpResponse,
entry_time: time::Tm) -> Result<(), fmt::Error>
{
fn render<S>(
&self, fmt: &mut Formatter, req: &HttpRequest<S>, resp: &HttpResponse,
entry_time: time::Tm,
) -> Result<(), fmt::Error> {
match *self {
FormatText::Str(ref string) => fmt.write_str(string),
FormatText::Percent => "%".fmt(fmt),
@ -220,26 +219,33 @@ impl FormatText {
if req.query_string().is_empty() {
fmt.write_fmt(format_args!(
"{} {} {:?}",
req.method(), req.path(), req.version()))
req.method(),
req.path(),
req.version()
))
} else {
fmt.write_fmt(format_args!(
"{} {}?{} {:?}",
req.method(), req.path(), req.query_string(), req.version()))
req.method(),
req.path(),
req.query_string(),
req.version()
))
}
},
}
FormatText::ResponseStatus => resp.status().as_u16().fmt(fmt),
FormatText::ResponseSize => resp.response_size().fmt(fmt),
FormatText::Pid => unsafe{libc::getpid().fmt(fmt)},
FormatText::Pid => unsafe { libc::getpid().fmt(fmt) },
FormatText::Time => {
let rt = time::now() - entry_time;
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000_000.0;
fmt.write_fmt(format_args!("{:.6}", rt))
},
}
FormatText::TimeMillis => {
let rt = time::now() - entry_time;
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000.0;
fmt.write_fmt(format_args!("{:.6}", rt))
},
}
FormatText::RemoteAddr => {
if let Some(remote) = req.connection_info().remote() {
return remote.fmt(fmt);
@ -247,14 +253,17 @@ impl FormatText {
"-".fmt(fmt)
}
}
FormatText::RequestTime => {
entry_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")
.unwrap()
.fmt(fmt)
}
FormatText::RequestTime => entry_time
.strftime("[%d/%b/%Y:%H:%M:%S %z]")
.unwrap()
.fmt(fmt),
FormatText::RequestHeader(ref name) => {
let s = if let Some(val) = req.headers().get(name) {
if let Ok(s) = val.to_str() { s } else { "-" }
if let Ok(s) = val.to_str() {
s
} else {
"-"
}
} else {
"-"
};
@ -262,7 +271,11 @@ impl FormatText {
}
FormatText::ResponseHeader(ref name) => {
let s = if let Some(val) = resp.headers().get(name) {
if let Ok(s) = val.to_str() { s } else { "-" }
if let Ok(s) = val.to_str() {
s
} else {
"-"
}
} else {
"-"
};
@ -279,8 +292,7 @@ impl FormatText {
}
}
pub(crate) struct FormatDisplay<'a>(
&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
pub(crate) struct FormatDisplay<'a>(&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
impl<'a> fmt::Display for FormatDisplay<'a> {
fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
@ -291,19 +303,27 @@ impl<'a> fmt::Display for FormatDisplay<'a> {
#[cfg(test)]
mod tests {
use super::*;
use http::header::{self, HeaderMap};
use http::{Method, StatusCode, Uri, Version};
use std::str::FromStr;
use time;
use http::{Method, Version, StatusCode, Uri};
use http::header::{self, HeaderMap};
#[test]
fn test_logger() {
let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test");
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB"));
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static("ACTIX-WEB"),
);
let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let resp = HttpResponse::build(StatusCode::OK)
.header("X-Test", "ttt")
.force_close()
@ -333,10 +353,20 @@ mod tests {
let format = Format::default();
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB"));
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static("ACTIX-WEB"),
);
let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
let resp = HttpResponse::build(StatusCode::OK).force_close().finish();
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
let entry_time = time::now();
let render = |fmt: &mut Formatter| {
@ -351,9 +381,15 @@ mod tests {
assert!(s.contains("ACTIX-WEB"));
let req = HttpRequest::new(
Method::GET, Uri::from_str("/?test").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
let resp = HttpResponse::build(StatusCode::OK).force_close().finish();
Method::GET,
Uri::from_str("/?test").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
let entry_time = time::now();
let render = |fmt: &mut Formatter| {

View File

@ -7,19 +7,19 @@ use httpresponse::HttpResponse;
mod logger;
#[cfg(feature = "session")]
mod session;
mod defaultheaders;
mod errhandlers;
pub mod cors;
pub mod csrf;
pub use self::logger::Logger;
pub use self::errhandlers::ErrorHandlers;
mod defaultheaders;
mod errhandlers;
#[cfg(feature = "session")]
mod session;
pub use self::defaultheaders::DefaultHeaders;
pub use self::errhandlers::ErrorHandlers;
pub use self::logger::Logger;
#[cfg(feature = "session")]
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage,
CookieSessionError, CookieSessionBackend};
pub use self::session::{CookieSessionBackend, CookieSessionError, RequestSession,
Session, SessionBackend, SessionImpl, SessionStorage};
/// Middleware start result
pub enum Started {
@ -29,7 +29,7 @@ pub enum Started {
/// handler execution halts.
Response(HttpResponse),
/// Execution completed, runs future to completion.
Future(Box<Future<Item=Option<HttpResponse>, Error=Error>>),
Future(Box<Future<Item = Option<HttpResponse>, Error = Error>>),
}
/// Middleware execution result
@ -37,7 +37,7 @@ pub enum Response {
/// New http response got generated
Done(HttpResponse),
/// Result is a future that resolves to a new http response
Future(Box<Future<Item=HttpResponse, Error=Error>>),
Future(Box<Future<Item = HttpResponse, Error = Error>>),
}
/// Middleware finish result
@ -45,13 +45,12 @@ pub enum Finished {
/// Execution completed
Done,
/// Execution completed, but run future to completion
Future(Box<Future<Item=(), Error=Error>>),
Future(Box<Future<Item = (), Error = Error>>),
}
/// Middleware definition
#[allow(unused_variables)]
pub trait Middleware<S>: 'static {
/// Method is called when request is ready. It may return
/// future, which should resolve before next middleware get called.
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
@ -60,7 +59,9 @@ pub trait Middleware<S>: 'static {
/// Method is called when handler returns response,
/// but before sending http message to peer.
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
Ok(Response::Done(resp))
}

View File

@ -1,21 +1,21 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::Arc;
use std::marker::PhantomData;
use std::collections::HashMap;
use cookie::{Cookie, CookieJar, Key};
use futures::Future;
use futures::future::{FutureResult, err as FutErr, ok as FutOk};
use http::header::{self, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json;
use serde_json::error::Error as JsonError;
use serde::{Serialize, Deserialize};
use http::header::{self, HeaderValue};
use time::Duration;
use cookie::{CookieJar, Cookie, Key};
use futures::Future;
use futures::future::{FutureResult, ok as FutOk, err as FutErr};
use error::{Result, Error, ResponseError};
use error::{Error, ResponseError, Result};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Started, Response};
use middleware::{Middleware, Response, Started};
/// The helper trait to obtain your session data from a request.
///
@ -40,14 +40,13 @@ pub trait RequestSession {
}
impl<S> RequestSession for HttpRequest<S> {
fn session(&mut self) -> Session {
if let Some(s_impl) = self.extensions().get_mut::<Arc<SessionImplBox>>() {
if let Some(s) = Arc::get_mut(s_impl) {
return Session(s.0.as_mut())
return Session(s.0.as_mut());
}
}
Session(unsafe{&mut DUMMY})
Session(unsafe { &mut DUMMY })
}
}
@ -76,7 +75,6 @@ impl<S> RequestSession for HttpRequest<S> {
pub struct Session<'a>(&'a mut SessionImpl);
impl<'a> Session<'a> {
/// Get a `value` from the session.
pub fn get<T: Deserialize<'a>>(&'a self, key: &str) -> Result<Option<T>> {
if let Some(s) = self.0.get(key) {
@ -136,24 +134,25 @@ impl<S, T: SessionBackend<S>> SessionStorage<T, S> {
}
impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
let mut req = req.clone();
let fut = self.0.from_request(&mut req)
.then(move |res| {
match res {
Ok(sess) => {
req.extensions().insert(Arc::new(SessionImplBox(Box::new(sess))));
FutOk(None)
},
Err(err) => FutErr(err)
let fut = self.0
.from_request(&mut req)
.then(move |res| match res {
Ok(sess) => {
req.extensions()
.insert(Arc::new(SessionImplBox(Box::new(sess))));
FutOk(None)
}
Err(err) => FutErr(err),
});
Ok(Started::Future(Box::new(fut)))
}
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
if let Some(s_box) = req.extensions().remove::<Arc<SessionImplBox>>() {
s_box.0.write(resp)
} else {
@ -165,7 +164,6 @@ impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
/// A simple key-value storage interface that is internally used by `Session`.
#[doc(hidden)]
pub trait SessionImpl: 'static {
fn get(&self, key: &str) -> Option<&str>;
fn set(&mut self, key: &str, value: String);
@ -182,7 +180,7 @@ pub trait SessionImpl: 'static {
#[doc(hidden)]
pub trait SessionBackend<S>: Sized + 'static {
type Session: SessionImpl;
type ReadFuture: Future<Item=Self::Session, Error=Error>;
type ReadFuture: Future<Item = Self::Session, Error = Error>;
/// Parse the session from request and load data from a storage backend.
fn from_request(&self, request: &mut HttpRequest<S>) -> Self::ReadFuture;
@ -194,8 +192,9 @@ struct DummySessionImpl;
static mut DUMMY: DummySessionImpl = DummySessionImpl;
impl SessionImpl for DummySessionImpl {
fn get(&self, _: &str) -> Option<&str> { None }
fn get(&self, _: &str) -> Option<&str> {
None
}
fn set(&mut self, _: &str, _: String) {}
fn remove(&mut self, _: &str) {}
fn clear(&mut self) {}
@ -215,17 +214,16 @@ pub struct CookieSession {
#[derive(Fail, Debug)]
pub enum CookieSessionError {
/// Size of the serialized session is greater than 4000 bytes.
#[fail(display="Size of the serialized session is greater than 4000 bytes.")]
#[fail(display = "Size of the serialized session is greater than 4000 bytes.")]
Overflow,
/// Fail to serialize session.
#[fail(display="Fail to serialize session")]
#[fail(display = "Fail to serialize session")]
Serialize(JsonError),
}
impl ResponseError for CookieSessionError {}
impl SessionImpl for CookieSession {
fn get(&self, key: &str) -> Option<&str> {
if let Some(s) = self.state.get(key) {
Some(s)
@ -259,7 +257,7 @@ impl SessionImpl for CookieSession {
enum CookieSecurity {
Signed,
Private
Private,
}
struct CookieSessionInner {
@ -273,7 +271,6 @@ struct CookieSessionInner {
}
impl CookieSessionInner {
fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner {
CookieSessionInner {
security,
@ -286,11 +283,13 @@ impl CookieSessionInner {
}
}
fn set_cookie(&self, resp: &mut HttpResponse, state: &HashMap<String, String>) -> Result<()> {
let value = serde_json::to_string(&state)
.map_err(CookieSessionError::Serialize)?;
fn set_cookie(
&self, resp: &mut HttpResponse, state: &HashMap<String, String>
) -> Result<()> {
let value =
serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?;
if value.len() > 4064 {
return Err(CookieSessionError::Overflow.into())
return Err(CookieSessionError::Overflow.into());
}
let mut cookie = Cookie::new(self.name.clone(), value);
@ -330,7 +329,9 @@ impl CookieSessionInner {
let cookie_opt = match self.security {
CookieSecurity::Signed => jar.signed(&self.key).get(&self.name),
CookieSecurity::Private => jar.private(&self.key).get(&self.name),
CookieSecurity::Private => {
jar.private(&self.key).get(&self.name)
}
};
if let Some(cookie) = cookie_opt {
if let Ok(val) = serde_json::from_str(cookie.value()) {
@ -347,20 +348,24 @@ impl CookieSessionInner {
/// Use cookies for session storage.
///
/// `CookieSessionBackend` creates sessions which are limited to storing
/// fewer than 4000 bytes of data (as the payload must fit into a single cookie).
/// An Internal Server Error is generated if the session contains more than 4000 bytes.
/// fewer than 4000 bytes of data (as the payload must fit into a single
/// cookie). An Internal Server Error is generated if the session contains more
/// than 4000 bytes.
///
/// A cookie may have a security policy of *signed* or *private*. Each has a respective `CookieSessionBackend` constructor.
/// A cookie may have a security policy of *signed* or *private*. Each has a
/// respective `CookieSessionBackend` constructor.
///
/// A *signed* cookie is stored on the client as plaintext alongside
/// a signature such that the cookie may be viewed but not modified by the client.
/// a signature such that the cookie may be viewed but not modified by the
/// client.
///
/// A *private* cookie is stored on the client as encrypted text
/// such that it may neither be viewed nor modified by the client.
///
/// The constructors take a key as an argument.
/// This is the private key for cookie session - when this value is changed, all session data is lost.
/// The constructors will panic if the key is less than 32 bytes in length.
/// This is the private key for cookie session - when this value is changed,
/// all session data is lost. The constructors will panic if the key is less
/// than 32 bytes in length.
///
///
/// # Example
@ -380,21 +385,24 @@ impl CookieSessionInner {
pub struct CookieSessionBackend(Rc<CookieSessionInner>);
impl CookieSessionBackend {
/// Construct new *signed* `CookieSessionBackend` instance.
///
/// Panics if key length is less than 32 bytes.
pub fn signed(key: &[u8]) -> CookieSessionBackend {
CookieSessionBackend(
Rc::new(CookieSessionInner::new(key, CookieSecurity::Signed)))
CookieSessionBackend(Rc::new(CookieSessionInner::new(
key,
CookieSecurity::Signed,
)))
}
/// Construct new *private* `CookieSessionBackend` instance.
///
/// Panics if key length is less than 32 bytes.
pub fn private(key: &[u8]) -> CookieSessionBackend {
CookieSessionBackend(
Rc::new(CookieSessionInner::new(key, CookieSecurity::Private)))
CookieSessionBackend(Rc::new(CookieSessionInner::new(
key,
CookieSecurity::Private,
)))
}
/// Sets the `path` field in the session cookie being built.
@ -432,17 +440,15 @@ impl CookieSessionBackend {
}
impl<S> SessionBackend<S> for CookieSessionBackend {
type Session = CookieSession;
type ReadFuture = FutureResult<CookieSession, Error>;
fn from_request(&self, req: &mut HttpRequest<S>) -> Self::ReadFuture {
let state = self.0.load(req);
FutOk(
CookieSession {
changed: false,
inner: Rc::clone(&self.0),
state,
})
FutOk(CookieSession {
changed: false,
inner: Rc::clone(&self.0),
state,
})
}
}

View File

@ -1,18 +1,18 @@
//! Multipart requests support
use std::{cmp, fmt};
use std::rc::Rc;
use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;
use std::{cmp, fmt};
use mime;
use httparse;
use bytes::Bytes;
use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream};
use http::HttpTryFrom;
use http::header::{self, HeaderMap, HeaderName, HeaderValue};
use futures::{Async, Stream, Poll};
use futures::task::{Task, current as current_task};
use httparse;
use mime;
use error::{ParseError, PayloadError, MultipartError};
use error::{MultipartError, ParseError, PayloadError};
use payload::PayloadHelper;
const MAX_HEADERS: usize = 32;
@ -85,33 +85,36 @@ impl Multipart<()> {
}
}
impl<S> Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Multipart<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create multipart instance for boundary.
pub fn new(boundary: Result<String, MultipartError>, stream: S) -> Multipart<S> {
match boundary {
Ok(boundary) => Multipart {
error: None,
safety: Safety::new(),
inner: Some(Rc::new(RefCell::new(
InnerMultipart {
boundary,
payload: PayloadRef::new(PayloadHelper::new(stream)),
state: InnerState::FirstBoundary,
item: InnerMultipartItem::None,
})))
inner: Some(Rc::new(RefCell::new(InnerMultipart {
boundary,
payload: PayloadRef::new(PayloadHelper::new(stream)),
state: InnerState::FirstBoundary,
item: InnerMultipartItem::None,
}))),
},
Err(err) => Multipart {
error: Some(err),
safety: Safety::new(),
inner: None,
},
Err(err) =>
Multipart {
error: Some(err),
safety: Safety::new(),
inner: None,
}
}
}
}
impl<S> Stream for Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Stream for Multipart<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = MultipartItem<S>;
type Error = MultipartError;
@ -119,17 +122,22 @@ impl<S> Stream for Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError>
if let Some(err) = self.error.take() {
Err(err)
} else if self.safety.current() {
self.inner.as_mut().unwrap().borrow_mut().poll(&self.safety)
self.inner
.as_mut()
.unwrap()
.borrow_mut()
.poll(&self.safety)
} else {
Ok(Async::NotReady)
}
}
}
impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
fn read_headers(payload: &mut PayloadHelper<S>) -> Poll<HeaderMap, MultipartError>
{
impl<S> InnerMultipart<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
fn read_headers(payload: &mut PayloadHelper<S>) -> Poll<HeaderMap, MultipartError> {
match payload.read_until(b"\r\n\r\n")? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete),
@ -144,10 +152,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
if let Ok(value) = HeaderValue::try_from(h.value) {
headers.append(name, value);
} else {
return Err(ParseError::Header.into())
return Err(ParseError::Header.into());
}
} else {
return Err(ParseError::Header.into())
return Err(ParseError::Header.into());
}
}
Ok(Async::Ready(headers))
@ -159,23 +167,21 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
}
fn read_boundary(payload: &mut PayloadHelper<S>, boundary: &str)
-> Poll<bool, MultipartError>
{
fn read_boundary(
payload: &mut PayloadHelper<S>, boundary: &str
) -> Poll<bool, MultipartError> {
// TODO: need to read epilogue
match payload.readline()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete),
Async::Ready(Some(chunk)) => {
if chunk.len() == boundary.len() + 4 &&
&chunk[..2] == b"--" &&
&chunk[2..boundary.len()+2] == boundary.as_bytes()
if chunk.len() == boundary.len() + 4 && &chunk[..2] == b"--"
&& &chunk[2..boundary.len() + 2] == boundary.as_bytes()
{
Ok(Async::Ready(false))
} else if chunk.len() == boundary.len() + 6 &&
&chunk[..2] == b"--" &&
&chunk[2..boundary.len()+2] == boundary.as_bytes() &&
&chunk[boundary.len()+2..boundary.len()+4] == b"--"
} else if chunk.len() == boundary.len() + 6 && &chunk[..2] == b"--"
&& &chunk[2..boundary.len() + 2] == boundary.as_bytes()
&& &chunk[boundary.len() + 2..boundary.len() + 4] == b"--"
{
Ok(Async::Ready(true))
} else {
@ -185,9 +191,9 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
}
fn skip_until_boundary(payload: &mut PayloadHelper<S>, boundary: &str)
-> Poll<bool, MultipartError>
{
fn skip_until_boundary(
payload: &mut PayloadHelper<S>, boundary: &str
) -> Poll<bool, MultipartError> {
let mut eof = false;
loop {
match payload.readline()? {
@ -197,22 +203,25 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
//% (self._boundary))
}
if chunk.len() < boundary.len() {
continue
continue;
}
if &chunk[..2] == b"--" && &chunk[2..chunk.len()-2] == boundary.as_bytes() {
if &chunk[..2] == b"--"
&& &chunk[2..chunk.len() - 2] == boundary.as_bytes()
{
break;
} else {
if chunk.len() < boundary.len() + 2{
continue
if chunk.len() < boundary.len() + 2 {
continue;
}
let b: &[u8] = boundary.as_ref();
if &chunk[..boundary.len()] == b &&
&chunk[boundary.len()..boundary.len()+2] == b"--" {
eof = true;
break;
}
if &chunk[..boundary.len()] == b
&& &chunk[boundary.len()..boundary.len() + 2] == b"--"
{
eof = true;
break;
}
}
},
}
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(None) => return Err(MultipartError::Incomplete),
}
@ -220,7 +229,9 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Ok(Async::Ready(eof))
}
fn poll(&mut self, safety: &Safety) -> Poll<Option<MultipartItem<S>>, MultipartError> {
fn poll(
&mut self, safety: &Safety
) -> Poll<Option<MultipartItem<S>>, MultipartError> {
if self.state == InnerState::Eof {
Ok(Async::Ready(None))
} else {
@ -236,14 +247,14 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Async::Ready(Some(_)) => continue,
Async::Ready(None) => true,
}
},
}
InnerMultipartItem::Multipart(ref mut multipart) => {
match multipart.borrow_mut().poll(safety)? {
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(_)) => continue,
Async::Ready(None) => true,
}
},
}
_ => false,
};
if stop {
@ -259,7 +270,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
match self.state {
// read until first boundary
InnerState::FirstBoundary => {
match InnerMultipart::skip_until_boundary(payload, &self.boundary)? {
match InnerMultipart::skip_until_boundary(
payload,
&self.boundary,
)? {
Async::Ready(eof) => {
if eof {
self.state = InnerState::Eof;
@ -267,10 +281,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} else {
self.state = InnerState::Headers;
}
},
}
Async::NotReady => return Ok(Async::NotReady),
}
},
}
// read boundary
InnerState::Boundary => {
match InnerMultipart::read_boundary(payload, &self.boundary)? {
@ -290,18 +304,19 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// read field headers for next field
if self.state == InnerState::Headers {
if let Async::Ready(headers) = InnerMultipart::read_headers(payload)? {
if let Async::Ready(headers) = InnerMultipart::read_headers(payload)?
{
self.state = InnerState::Boundary;
headers
} else {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
} else {
unreachable!()
}
} else {
debug!("NotReady: field is in flight");
return Ok(Async::NotReady)
return Ok(Async::NotReady);
};
// content type
@ -319,32 +334,37 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// nested multipart stream
if mt.type_() == mime::MULTIPART {
let inner = if let Some(boundary) = mt.get_param(mime::BOUNDARY) {
Rc::new(RefCell::new(
InnerMultipart {
payload: self.payload.clone(),
boundary: boundary.as_str().to_owned(),
state: InnerState::FirstBoundary,
item: InnerMultipartItem::None,
}))
Rc::new(RefCell::new(InnerMultipart {
payload: self.payload.clone(),
boundary: boundary.as_str().to_owned(),
state: InnerState::FirstBoundary,
item: InnerMultipartItem::None,
}))
} else {
return Err(MultipartError::Boundary)
return Err(MultipartError::Boundary);
};
self.item = InnerMultipartItem::Multipart(Rc::clone(&inner));
Ok(Async::Ready(Some(
MultipartItem::Nested(
Multipart{safety: safety.clone(),
error: None,
inner: Some(inner)}))))
Ok(Async::Ready(Some(MultipartItem::Nested(Multipart {
safety: safety.clone(),
error: None,
inner: Some(inner),
}))))
} else {
let field = Rc::new(RefCell::new(InnerField::new(
self.payload.clone(), self.boundary.clone(), &headers)?));
self.payload.clone(),
self.boundary.clone(),
&headers,
)?));
self.item = InnerMultipartItem::Field(Rc::clone(&field));
Ok(Async::Ready(Some(
MultipartItem::Field(
Field::new(safety.clone(), headers, mt, field)))))
Ok(Async::Ready(Some(MultipartItem::Field(Field::new(
safety.clone(),
headers,
mt,
field,
)))))
}
}
}
@ -365,11 +385,20 @@ pub struct Field<S> {
safety: Safety,
}
impl<S> Field<S> where S: Stream<Item=Bytes, Error=PayloadError> {
fn new(safety: Safety, headers: HeaderMap,
ct: mime::Mime, inner: Rc<RefCell<InnerField<S>>>) -> Self {
Field {ct, headers, inner, safety}
impl<S> Field<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
fn new(
safety: Safety, headers: HeaderMap, ct: mime::Mime,
inner: Rc<RefCell<InnerField<S>>>,
) -> Self {
Field {
ct,
headers,
inner,
safety,
}
}
pub fn headers(&self) -> &HeaderMap {
@ -381,7 +410,10 @@ impl<S> Field<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
}
impl<S> Stream for Field<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Stream for Field<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = Bytes;
type Error = MultipartError;
@ -413,20 +445,22 @@ struct InnerField<S> {
length: Option<u64>,
}
impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
fn new(payload: PayloadRef<S>, boundary: String, headers: &HeaderMap)
-> Result<InnerField<S>, PayloadError>
{
impl<S> InnerField<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
fn new(
payload: PayloadRef<S>, boundary: String, headers: &HeaderMap
) -> Result<InnerField<S>, PayloadError> {
let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) {
if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() {
Some(len)
} else {
return Err(PayloadError::Incomplete)
return Err(PayloadError::Incomplete);
}
} else {
return Err(PayloadError::Incomplete)
return Err(PayloadError::Incomplete);
}
} else {
None
@ -436,14 +470,15 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
boundary,
payload: Some(payload),
eof: false,
length: len })
length: len,
})
}
/// Reads body part content chunk of the specified size.
/// The body part must has `Content-Length` header with proper value.
fn read_len(payload: &mut PayloadHelper<S>, size: &mut u64)
-> Poll<Option<Bytes>, MultipartError>
{
fn read_len(
payload: &mut PayloadHelper<S>, size: &mut u64
) -> Poll<Option<Bytes>, MultipartError> {
if *size == 0 {
Ok(Async::Ready(None))
} else {
@ -458,17 +493,17 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
payload.unread_data(chunk);
}
Ok(Async::Ready(Some(ch)))
},
Err(err) => Err(err.into())
}
Err(err) => Err(err.into()),
}
}
}
/// Reads content chunk of body part with unknown length.
/// The `Content-Length` header for body part is not necessary.
fn read_stream(payload: &mut PayloadHelper<S>, boundary: &str)
-> Poll<Option<Bytes>, MultipartError>
{
fn read_stream(
payload: &mut PayloadHelper<S>, boundary: &str
) -> Poll<Option<Bytes>, MultipartError> {
match payload.read_until(b"\r")? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete),
@ -479,8 +514,8 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete),
Async::Ready(Some(chunk)) => {
if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--" &&
&chunk[4..] == boundary.as_bytes()
if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--"
&& &chunk[4..] == boundary.as_bytes()
{
payload.unread_data(chunk);
Ok(Async::Ready(None))
@ -501,7 +536,7 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> {
if self.payload.is_none() {
return Ok(Async::Ready(None))
return Ok(Async::Ready(None));
}
let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) {
@ -543,7 +578,10 @@ struct PayloadRef<S> {
payload: Rc<PayloadHelper<S>>,
}
impl<S> PayloadRef<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> PayloadRef<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
fn new(payload: PayloadHelper<S>) -> PayloadRef<S> {
PayloadRef {
payload: Rc::new(payload),
@ -551,11 +589,12 @@ impl<S> PayloadRef<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadHelper<S>>
where 'a: 'b
where
'a: 'b,
{
if s.current() {
let payload: &mut PayloadHelper<S> = unsafe {
&mut *(self.payload.as_ref() as *const _ as *mut _)};
let payload: &mut PayloadHelper<S> =
unsafe { &mut *(self.payload.as_ref() as *const _ as *mut _) };
Some(payload)
} else {
None
@ -571,8 +610,9 @@ impl<S> Clone for PayloadRef<S> {
}
}
/// Counter. It tracks of number of clones of payloads and give access to payload only
/// to top most task panics if Safety get destroyed and it not top most task.
/// Counter. It tracks of number of clones of payloads and give access to
/// payload only to top most task panics if Safety get destroyed and it not top
/// most task.
#[derive(Debug)]
struct Safety {
task: Option<Task>,
@ -593,7 +633,6 @@ impl Safety {
fn current(&self) -> bool {
Rc::strong_count(&self.payload) == self.level
}
}
impl Clone for Safety {
@ -624,8 +663,8 @@ mod tests {
use super::*;
use bytes::Bytes;
use futures::future::{lazy, result};
use tokio_core::reactor::Core;
use payload::{Payload, PayloadWriter};
use tokio_core::reactor::Core;
#[test]
fn test_boundary() {
@ -636,8 +675,10 @@ mod tests {
}
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE,
header::HeaderValue::from_static("test"));
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("test"),
);
match Multipart::boundary(&headers) {
Err(MultipartError::ParseContentType) => (),
@ -647,7 +688,8 @@ mod tests {
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("multipart/mixed"));
header::HeaderValue::from_static("multipart/mixed"),
);
match Multipart::boundary(&headers) {
Err(MultipartError::Boundary) => (),
_ => unreachable!("should not happen"),
@ -657,18 +699,24 @@ mod tests {
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static(
"multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\""));
"multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"",
),
);
assert_eq!(Multipart::boundary(&headers).unwrap(),
"5c02368e880e436dab70ed54e1c58209");
assert_eq!(
Multipart::boundary(&headers).unwrap(),
"5c02368e880e436dab70ed54e1c58209"
);
}
#[test]
fn test_multipart() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let bytes = Bytes::from(
let bytes = Bytes::from(
"testasdadsad\r\n\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
@ -677,63 +725,64 @@ mod tests {
Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
data\r\n\
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n");
sender.feed_data(bytes);
sender.feed_data(bytes);
let mut multipart = Multipart::new(
Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()), payload);
match multipart.poll() {
Ok(Async::Ready(Some(item))) => {
match item {
let mut multipart = Multipart::new(
Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()),
payload,
);
match multipart.poll() {
Ok(Async::Ready(Some(item))) => match item {
MultipartItem::Field(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() {
Ok(Async::Ready(Some(chunk))) =>
assert_eq!(chunk, "test"),
_ => unreachable!()
Ok(Async::Ready(Some(chunk))) => {
assert_eq!(chunk, "test")
}
_ => unreachable!(),
}
match field.poll() {
Ok(Async::Ready(None)) => (),
_ => unreachable!()
_ => unreachable!(),
}
},
_ => unreachable!()
}
}
_ => unreachable!(),
},
_ => unreachable!(),
}
_ => unreachable!()
}
match multipart.poll() {
Ok(Async::Ready(Some(item))) => {
match item {
match multipart.poll() {
Ok(Async::Ready(Some(item))) => match item {
MultipartItem::Field(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() {
Ok(Async::Ready(Some(chunk))) =>
assert_eq!(chunk, "data"),
_ => unreachable!()
Ok(Async::Ready(Some(chunk))) => {
assert_eq!(chunk, "data")
}
_ => unreachable!(),
}
match field.poll() {
Ok(Async::Ready(None)) => (),
_ => unreachable!()
_ => unreachable!(),
}
},
_ => unreachable!()
}
}
_ => unreachable!(),
},
_ => unreachable!(),
}
_ => unreachable!()
}
match multipart.poll() {
Ok(Async::Ready(None)) => (),
_ => unreachable!()
}
match multipart.poll() {
Ok(Async::Ready(None)) => (),
_ => unreachable!(),
}
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
}

View File

@ -1,16 +1,16 @@
use std;
use std::ops::Index;
use std::path::PathBuf;
use std::str::FromStr;
use std::slice::Iter;
use std::borrow::Cow;
use http::StatusCode;
use smallvec::SmallVec;
use std;
use std::borrow::Cow;
use std::ops::Index;
use std::path::PathBuf;
use std::slice::Iter;
use std::str::FromStr;
use error::{ResponseError, UriSegmentError, InternalError};
use error::{InternalError, ResponseError, UriSegmentError};
/// A trait to abstract the idea of creating a new instance of a type from a path parameter.
/// A trait to abstract the idea of creating a new instance of a type from a
/// path parameter.
pub trait FromParam: Sized {
/// The associated error which can be returned from parsing.
type Err: ResponseError;
@ -26,7 +26,6 @@ pub trait FromParam: Sized {
pub struct Params<'a>(SmallVec<[(Cow<'a, str>, Cow<'a, str>); 3]>);
impl<'a> Params<'a> {
pub(crate) fn new() -> Params<'a> {
Params(SmallVec::new())
}
@ -36,7 +35,9 @@ impl<'a> Params<'a> {
}
pub(crate) fn add<N, V>(&mut self, name: N, value: V)
where N: Into<Cow<'a, str>>, V: Into<Cow<'a, str>>,
where
N: Into<Cow<'a, str>>,
V: Into<Cow<'a, str>>,
{
self.0.push((name.into(), value.into()));
}
@ -55,7 +56,7 @@ impl<'a> Params<'a> {
pub fn get(&'a self, key: &str) -> Option<&'a str> {
for item in self.0.iter() {
if key == item.0 {
return Some(item.1.as_ref())
return Some(item.1.as_ref());
}
}
None
@ -63,7 +64,8 @@ impl<'a> Params<'a> {
/// Get matched `FromParam` compatible parameter by name.
///
/// If keyed parameter is not available empty string is used as default value.
/// If keyed parameter is not available empty string is used as default
/// value.
///
/// ```rust
/// # extern crate actix_web;
@ -74,8 +76,7 @@ impl<'a> Params<'a> {
/// }
/// # fn main() {}
/// ```
pub fn query<T: FromParam>(&'a self, key: &str) -> Result<T, <T as FromParam>::Err>
{
pub fn query<T: FromParam>(&'a self, key: &str) -> Result<T, <T as FromParam>::Err> {
if let Some(s) = self.get(key) {
T::from_param(s)
} else {
@ -93,7 +94,8 @@ impl<'a, 'b, 'c: 'a> Index<&'b str> for &'c Params<'a> {
type Output = str;
fn index(&self, name: &'b str) -> &str {
self.get(name).expect("Value for parameter is not available")
self.get(name)
.expect("Value for parameter is not available")
}
}
@ -118,9 +120,9 @@ impl<'a, 'c: 'a> Index<usize> for &'c Params<'a> {
/// * On Windows, decoded segment contains any of: '\'
/// * Percent-encoding results in invalid UTF8.
///
/// As a result of these conditions, a `PathBuf` parsed from request path parameter is
/// safe to interpolate within, or use as a suffix of, a path without additional
/// checks.
/// As a result of these conditions, a `PathBuf` parsed from request path
/// parameter is safe to interpolate within, or use as a suffix of, a path
/// without additional checks.
impl FromParam for PathBuf {
type Err = UriSegmentError;
@ -130,19 +132,19 @@ impl FromParam for PathBuf {
if segment == ".." {
buf.pop();
} else if segment.starts_with('.') {
return Err(UriSegmentError::BadStart('.'))
return Err(UriSegmentError::BadStart('.'));
} else if segment.starts_with('*') {
return Err(UriSegmentError::BadStart('*'))
return Err(UriSegmentError::BadStart('*'));
} else if segment.ends_with(':') {
return Err(UriSegmentError::BadEnd(':'))
return Err(UriSegmentError::BadEnd(':'));
} else if segment.ends_with('>') {
return Err(UriSegmentError::BadEnd('>'))
return Err(UriSegmentError::BadEnd('>'));
} else if segment.ends_with('<') {
return Err(UriSegmentError::BadEnd('<'))
return Err(UriSegmentError::BadEnd('<'));
} else if segment.is_empty() {
continue
continue;
} else if cfg!(windows) && segment.contains('\\') {
return Err(UriSegmentError::BadChar('\\'))
return Err(UriSegmentError::BadChar('\\'));
} else {
buf.push(segment)
}
@ -162,7 +164,7 @@ macro_rules! FROM_STR {
.map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))
}
}
}
};
}
FROM_STR!(u8);
@ -192,14 +194,33 @@ mod tests {
#[test]
fn test_path_buf() {
assert_eq!(PathBuf::from_param("/test/.tt"), Err(UriSegmentError::BadStart('.')));
assert_eq!(PathBuf::from_param("/test/*tt"), Err(UriSegmentError::BadStart('*')));
assert_eq!(PathBuf::from_param("/test/tt:"), Err(UriSegmentError::BadEnd(':')));
assert_eq!(PathBuf::from_param("/test/tt<"), Err(UriSegmentError::BadEnd('<')));
assert_eq!(PathBuf::from_param("/test/tt>"), Err(UriSegmentError::BadEnd('>')));
assert_eq!(PathBuf::from_param("/seg1/seg2/"),
Ok(PathBuf::from_iter(vec!["seg1", "seg2"])));
assert_eq!(PathBuf::from_param("/seg1/../seg2/"),
Ok(PathBuf::from_iter(vec!["seg2"])));
assert_eq!(
PathBuf::from_param("/test/.tt"),
Err(UriSegmentError::BadStart('.'))
);
assert_eq!(
PathBuf::from_param("/test/*tt"),
Err(UriSegmentError::BadStart('*'))
);
assert_eq!(
PathBuf::from_param("/test/tt:"),
Err(UriSegmentError::BadEnd(':'))
);
assert_eq!(
PathBuf::from_param("/test/tt<"),
Err(UriSegmentError::BadEnd('<'))
);
assert_eq!(
PathBuf::from_param("/test/tt>"),
Err(UriSegmentError::BadEnd('>'))
);
assert_eq!(
PathBuf::from_param("/seg1/seg2/"),
Ok(PathBuf::from_iter(vec!["seg1", "seg2"]))
);
assert_eq!(
PathBuf::from_param("/seg1/../seg2/"),
Ok(PathBuf::from_iter(vec!["seg2"]))
);
}
}

View File

@ -1,18 +1,17 @@
//! Payload stream
use std::cmp;
use std::rc::{Rc, Weak};
use std::cell::RefCell;
use std::collections::VecDeque;
use bytes::{Bytes, BytesMut};
use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream};
use futures::task::{Task, current as current_task};
use std::cell::RefCell;
use std::cmp;
use std::collections::VecDeque;
use std::rc::{Rc, Weak};
use error::PayloadError;
/// max buffer size 32k
pub(crate) const MAX_BUFFER_SIZE: usize = 32_768;
#[derive(Debug, PartialEq)]
pub(crate) enum PayloadStatus {
Read,
@ -22,9 +21,9 @@ pub(crate) enum PayloadStatus {
/// Buffered stream of bytes chunks
///
/// Payload stores chunks in a vector. First chunk can be received with `.readany()` method.
/// Payload stream is not thread safe. Payload does not notify current task when
/// new data is available.
/// Payload stores chunks in a vector. First chunk can be received with
/// `.readany()` method. Payload stream is not thread safe. Payload does not
/// notify current task when new data is available.
///
/// Payload stream can be used as `HttpResponse` body stream.
#[derive(Debug)]
@ -33,10 +32,10 @@ pub struct Payload {
}
impl Payload {
/// Create payload stream.
///
/// This method construct two objects responsible for bytes stream generation.
/// This method construct two objects responsible for bytes stream
/// generation.
///
/// * `PayloadSender` - *Sender* side of the stream
///
@ -44,13 +43,20 @@ impl Payload {
pub fn new(eof: bool) -> (PayloadSender, Payload) {
let shared = Rc::new(RefCell::new(Inner::new(eof)));
(PayloadSender{inner: Rc::downgrade(&shared)}, Payload{inner: shared})
(
PayloadSender {
inner: Rc::downgrade(&shared),
},
Payload { inner: shared },
)
}
/// Create empty payload
#[doc(hidden)]
pub fn empty() -> Payload {
Payload{inner: Rc::new(RefCell::new(Inner::new(true)))}
Payload {
inner: Rc::new(RefCell::new(Inner::new(true))),
}
}
/// Indicates EOF of payload
@ -103,13 +109,14 @@ impl Stream for Payload {
impl Clone for Payload {
fn clone(&self) -> Payload {
Payload{inner: Rc::clone(&self.inner)}
Payload {
inner: Rc::clone(&self.inner),
}
}
}
/// Payload writer interface.
pub(crate) trait PayloadWriter {
/// Set stream error.
fn set_error(&mut self, err: PayloadError);
@ -129,7 +136,6 @@ pub struct PayloadSender {
}
impl PayloadWriter for PayloadSender {
#[inline]
fn set_error(&mut self, err: PayloadError) {
if let Some(shared) = self.inner.upgrade() {
@ -186,7 +192,6 @@ struct Inner {
}
impl Inner {
fn new(eof: bool) -> Self {
Inner {
eof,
@ -292,8 +297,10 @@ pub struct PayloadHelper<S> {
stream: S,
}
impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> PayloadHelper<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
pub fn new(stream: S) -> Self {
PayloadHelper {
len: 0,
@ -309,16 +316,14 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[inline]
fn poll_stream(&mut self) -> Poll<bool, PayloadError> {
self.stream.poll().map(|res| {
match res {
Async::Ready(Some(data)) => {
self.len += data.len();
self.items.push_back(data);
Async::Ready(true)
},
Async::Ready(None) => Async::Ready(false),
Async::NotReady => Async::NotReady,
self.stream.poll().map(|res| match res {
Async::Ready(Some(data)) => {
self.len += data.len();
self.items.push_back(data);
Async::Ready(true)
}
Async::Ready(None) => Async::Ready(false),
Async::NotReady => Async::NotReady,
})
}
@ -373,11 +378,9 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
let buf = chunk.split_to(size);
self.items.push_front(chunk);
Ok(Async::Ready(Some(buf)))
}
else if size == chunk.len() {
} else if size == chunk.len() {
Ok(Async::Ready(Some(chunk)))
}
else {
} else {
let mut buf = BytesMut::with_capacity(size);
buf.extend_from_slice(&chunk);
@ -408,7 +411,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
let mut len = 0;
while len < size {
let mut chunk = self.items.pop_front().unwrap();
let rem = cmp::min(size-len, chunk.len());
let rem = cmp::min(size - len, chunk.len());
len += rem;
if rem < chunk.len() {
chunk.split_to(rem);
@ -427,7 +430,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
buf.extend_from_slice(&chunk[..rem]);
}
if buf.len() == size {
return Ok(Async::Ready(Some(buf)))
return Ok(Async::Ready(Some(buf)));
}
}
}
@ -454,8 +457,8 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
idx += 1;
if idx == line.len() {
num = no;
offset = pos+1;
length += pos+1;
offset = pos + 1;
length += pos + 1;
found = true;
break;
}
@ -483,7 +486,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
}
self.len -= length;
return Ok(Async::Ready(Some(buf.freeze())))
return Ok(Async::Ready(Some(buf.freeze())));
}
}
@ -505,171 +508,217 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[allow(dead_code)]
pub fn remaining(&mut self) -> Bytes {
self.items.iter_mut()
self.items
.iter_mut()
.fold(BytesMut::new(), |mut b, c| {
b.extend_from_slice(c);
b
}).freeze()
})
.freeze()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
use failure::Fail;
use futures::future::{lazy, result};
use std::io;
use tokio_core::reactor::Core;
#[test]
fn test_error() {
let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into();
let err: PayloadError =
io::Error::new(io::ErrorKind::Other, "ParseError").into();
assert_eq!(format!("{}", err), "ParseError");
assert_eq!(format!("{}", err.cause().unwrap()), "ParseError");
let err = PayloadError::Incomplete;
assert_eq!(format!("{}", err), "A payload reached EOF, but is not complete.");
assert_eq!(
format!("{}", err),
"A payload reached EOF, but is not complete."
);
}
#[test]
fn test_basic() {
Core::new().unwrap().run(lazy(|| {
let (_, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (_, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(payload.len, 0);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
assert_eq!(payload.len, 0);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_eof() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.feed_data(Bytes::from("data"));
sender.feed_eof();
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.feed_data(Bytes::from("data"));
sender.feed_eof();
assert_eq!(Async::Ready(Some(Bytes::from("data"))),
payload.readany().ok().unwrap());
assert_eq!(payload.len, 0);
assert_eq!(Async::Ready(None), payload.readany().ok().unwrap());
assert_eq!(
Async::Ready(Some(Bytes::from("data"))),
payload.readany().ok().unwrap()
);
assert_eq!(payload.len, 0);
assert_eq!(Async::Ready(None), payload.readany().ok().unwrap());
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_err() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.set_error(PayloadError::Incomplete);
payload.readany().err().unwrap();
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
sender.set_error(PayloadError::Incomplete);
payload.readany().err().unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_readany() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from("line1"))),
payload.readany().ok().unwrap());
assert_eq!(payload.len, 0);
assert_eq!(
Async::Ready(Some(Bytes::from("line1"))),
payload.readany().ok().unwrap()
);
assert_eq!(payload.len, 0);
assert_eq!(Async::Ready(Some(Bytes::from("line2"))),
payload.readany().ok().unwrap());
assert_eq!(payload.len, 0);
assert_eq!(
Async::Ready(Some(Bytes::from("line2"))),
payload.readany().ok().unwrap()
);
assert_eq!(payload.len, 0);
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_readexactly() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap());
assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap());
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from_static(b"li"))),
payload.read_exact(2).ok().unwrap());
assert_eq!(payload.len, 3);
assert_eq!(
Async::Ready(Some(Bytes::from_static(b"li"))),
payload.read_exact(2).ok().unwrap()
);
assert_eq!(payload.len, 3);
assert_eq!(Async::Ready(Some(Bytes::from_static(b"ne1l"))),
payload.read_exact(4).ok().unwrap());
assert_eq!(payload.len, 4);
assert_eq!(
Async::Ready(Some(Bytes::from_static(b"ne1l"))),
payload.read_exact(4).ok().unwrap()
);
assert_eq!(payload.len, 4);
sender.set_error(PayloadError::Incomplete);
payload.read_exact(10).err().unwrap();
sender.set_error(PayloadError::Incomplete);
payload.read_exact(10).err().unwrap();
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_readuntil() {
Core::new().unwrap().run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
Core::new()
.unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.read_until(b"ne").ok().unwrap());
assert_eq!(
Async::NotReady,
payload.read_until(b"ne").ok().unwrap()
);
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from("line"))),
payload.read_until(b"ne").ok().unwrap());
assert_eq!(payload.len, 1);
assert_eq!(
Async::Ready(Some(Bytes::from("line"))),
payload.read_until(b"ne").ok().unwrap()
);
assert_eq!(payload.len, 1);
assert_eq!(Async::Ready(Some(Bytes::from("1line2"))),
payload.read_until(b"2").ok().unwrap());
assert_eq!(payload.len, 0);
assert_eq!(
Async::Ready(Some(Bytes::from("1line2"))),
payload.read_until(b"2").ok().unwrap()
);
assert_eq!(payload.len, 0);
sender.set_error(PayloadError::Incomplete);
payload.read_until(b"b").err().unwrap();
sender.set_error(PayloadError::Incomplete);
payload.read_until(b"b").err().unwrap();
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
#[test]
fn test_unread_data() {
Core::new().unwrap().run(lazy(|| {
let (_, mut payload) = Payload::new(false);
Core::new()
.unwrap()
.run(lazy(|| {
let (_, mut payload) = Payload::new(false);
payload.unread_data(Bytes::from("data"));
assert!(!payload.is_empty());
assert_eq!(payload.len(), 4);
payload.unread_data(Bytes::from("data"));
assert!(!payload.is_empty());
assert_eq!(payload.len(), 4);
assert_eq!(Async::Ready(Some(Bytes::from("data"))),
payload.poll().ok().unwrap());
assert_eq!(
Async::Ready(Some(Bytes::from("data"))),
payload.poll().ok().unwrap()
);
let res: Result<(), ()> = Ok(());
result(res)
})).unwrap();
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
}
}

View File

@ -1,22 +1,22 @@
use std::{io, mem};
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::rc::Rc;
use std::{io, mem};
use log::Level::Debug;
use futures::{Async, Poll, Future, Stream};
use futures::unsync::oneshot;
use futures::{Async, Future, Poll, Stream};
use log::Level::Debug;
use application::Inner;
use body::{Body, BodyStream};
use context::{Frame, ActorHttpContext};
use context::{ActorHttpContext, Frame};
use error::Error;
use header::ContentEncoding;
use handler::{Reply, ReplyItem};
use header::ContentEncoding;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Finished, Started, Response};
use application::Inner;
use server::{Writer, WriterState, HttpHandlerTask};
use middleware::{Finished, Middleware, Response, Started};
use server::{HttpHandlerTask, Writer, WriterState};
#[derive(Debug, Clone, Copy)]
pub(crate) enum HandlerType {
@ -26,7 +26,6 @@ pub(crate) enum HandlerType {
}
pub(crate) trait PipelineHandler<S> {
fn encoding(&self) -> ContentEncoding;
fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply;
@ -46,7 +45,6 @@ enum PipelineState<S, H> {
}
impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> {
fn is_response(&self) -> bool {
match *self {
PipelineState::Response(_) => true,
@ -61,10 +59,12 @@ impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> {
PipelineState::RunMiddlewares(ref mut state) => state.poll(info),
PipelineState::Finishing(ref mut state) => state.poll(info),
PipelineState::Completed(ref mut state) => state.poll(info),
PipelineState::Response(_) | PipelineState::None | PipelineState::Error => None,
PipelineState::Response(_) | PipelineState::None | PipelineState::Error => {
None
}
}
}
}
}
struct PipelineInfo<S> {
req: HttpRequest<S>,
@ -92,7 +92,9 @@ impl<S> PipelineInfo<S> {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))]
fn req_mut(&self) -> &mut HttpRequest<S> {
#[allow(mutable_transmutes)]
unsafe{mem::transmute(&self.req)}
unsafe {
mem::transmute(&self.req)
}
}
fn poll_context(&mut self) -> Poll<(), Error> {
@ -109,18 +111,18 @@ impl<S> PipelineInfo<S> {
}
impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> {
pub fn new(req: HttpRequest<S>,
mws: Rc<Vec<Box<Middleware<S>>>>,
handler: Rc<UnsafeCell<H>>, htype: HandlerType) -> Pipeline<S, H>
{
pub fn new(
req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>,
handler: Rc<UnsafeCell<H>>, htype: HandlerType,
) -> Pipeline<S, H> {
let mut info = PipelineInfo {
req, mws,
req,
mws,
count: 0,
error: None,
context: None,
disconnected: None,
encoding: unsafe{&*handler.get()}.encoding(),
encoding: unsafe { &*handler.get() }.encoding(),
};
let state = StartMiddlewares::init(&mut info, handler, htype);
@ -131,30 +133,33 @@ impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> {
impl Pipeline<(), Inner<()>> {
pub fn error<R: Into<HttpResponse>>(err: R) -> Box<HttpHandlerTask> {
Box::new(Pipeline::<(), Inner<()>>(
PipelineInfo::new(HttpRequest::default()), ProcessResponse::init(err.into())))
PipelineInfo::new(HttpRequest::default()),
ProcessResponse::init(err.into()),
))
}
}
impl<S: 'static, H> Pipeline<S, H> {
fn is_done(&self) -> bool {
match self.1 {
PipelineState::None | PipelineState::Error
| PipelineState::Starting(_) | PipelineState::Handler(_)
| PipelineState::RunMiddlewares(_) | PipelineState::Response(_) => true,
PipelineState::None
| PipelineState::Error
| PipelineState::Starting(_)
| PipelineState::Handler(_)
| PipelineState::RunMiddlewares(_)
| PipelineState::Response(_) => true,
PipelineState::Finishing(_) | PipelineState::Completed(_) => false,
}
}
}
impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
fn disconnected(&mut self) {
self.0.disconnected = Some(true);
}
fn poll_io(&mut self, io: &mut Writer) -> Poll<bool, Error> {
let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) };
let info: &mut PipelineInfo<_> = unsafe { mem::transmute(&mut self.0) };
loop {
if self.1.is_response() {
@ -164,9 +169,9 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
Ok(state) => {
self.1 = state;
if let Some(error) = self.0.error.take() {
return Err(error)
return Err(error);
} else {
return Ok(Async::Ready(self.is_done()))
return Ok(Async::Ready(self.is_done()));
}
}
Err(state) => {
@ -177,11 +182,10 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
}
}
match self.1 {
PipelineState::None =>
return Ok(Async::Ready(true)),
PipelineState::Error =>
return Err(io::Error::new(
io::ErrorKind::Other, "Internal error").into()),
PipelineState::None => return Ok(Async::Ready(true)),
PipelineState::Error => {
return Err(io::Error::new(io::ErrorKind::Other, "Internal error").into())
}
_ => (),
}
@ -193,7 +197,7 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
}
fn poll(&mut self) -> Poll<(), Error> {
let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) };
let info: &mut PipelineInfo<_> = unsafe { mem::transmute(&mut self.0) };
loop {
match self.1 {
@ -212,7 +216,7 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
}
}
type Fut = Box<Future<Item=Option<HttpResponse>, Error=Error>>;
type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
/// Middlewares start executor
struct StartMiddlewares<S, H> {
@ -223,41 +227,40 @@ struct StartMiddlewares<S, H> {
}
impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
fn init(info: &mut PipelineInfo<S>, hnd: Rc<UnsafeCell<H>>, htype: HandlerType)
-> PipelineState<S, H>
{
// execute middlewares, we need this stage because middlewares could be non-async
// and we can move to next state immediately
fn init(
info: &mut PipelineInfo<S>, hnd: Rc<UnsafeCell<H>>, htype: HandlerType
) -> PipelineState<S, H> {
// execute middlewares, we need this stage because middlewares could be
// non-async and we can move to next state immediately
let len = info.mws.len() as u16;
loop {
if info.count == len {
let reply = unsafe{&mut *hnd.get()}.handle(info.req.clone(), htype);
return WaitingResponse::init(info, reply)
let reply = unsafe { &mut *hnd.get() }.handle(info.req.clone(), htype);
return WaitingResponse::init(info, reply);
} else {
match info.mws[info.count as usize].start(&mut info.req) {
Ok(Started::Done) =>
info.count += 1,
Ok(Started::Response(resp)) =>
return RunMiddlewares::init(info, resp),
Ok(Started::Future(mut fut)) =>
match fut.poll() {
Ok(Async::NotReady) =>
return PipelineState::Starting(StartMiddlewares {
hnd, htype,
fut: Some(fut),
_s: PhantomData}),
Ok(Async::Ready(resp)) => {
if let Some(resp) = resp {
return RunMiddlewares::init(info, resp);
}
info.count += 1;
Ok(Started::Done) => info.count += 1,
Ok(Started::Response(resp)) => {
return RunMiddlewares::init(info, resp)
}
Ok(Started::Future(mut fut)) => match fut.poll() {
Ok(Async::NotReady) => {
return PipelineState::Starting(StartMiddlewares {
hnd,
htype,
fut: Some(fut),
_s: PhantomData,
})
}
Ok(Async::Ready(resp)) => {
if let Some(resp) = resp {
return RunMiddlewares::init(info, resp);
}
Err(err) =>
return ProcessResponse::init(err.into()),
},
Err(err) =>
return ProcessResponse::init(err.into()),
info.count += 1;
}
Err(err) => return ProcessResponse::init(err.into()),
},
Err(err) => return ProcessResponse::init(err.into()),
}
}
}
@ -274,29 +277,28 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
return Some(RunMiddlewares::init(info, resp));
}
if info.count == len {
let reply = unsafe{
&mut *self.hnd.get()}.handle(info.req.clone(), self.htype);
let reply = unsafe { &mut *self.hnd.get() }
.handle(info.req.clone(), self.htype);
return Some(WaitingResponse::init(info, reply));
} else {
loop {
match info.mws[info.count as usize].start(info.req_mut()) {
Ok(Started::Done) =>
info.count += 1,
Ok(Started::Done) => info.count += 1,
Ok(Started::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp));
},
}
Ok(Started::Future(fut)) => {
self.fut = Some(fut);
continue 'outer
},
Err(err) =>
continue 'outer;
}
Err(err) => {
return Some(ProcessResponse::init(err.into()))
}
}
}
}
}
Err(err) =>
return Some(ProcessResponse::init(err.into()))
Err(err) => return Some(ProcessResponse::init(err.into())),
}
}
}
@ -304,31 +306,29 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
// waiting for response
struct WaitingResponse<S, H> {
fut: Box<Future<Item=HttpResponse, Error=Error>>,
fut: Box<Future<Item = HttpResponse, Error = Error>>,
_s: PhantomData<S>,
_h: PhantomData<H>,
}
impl<S: 'static, H> WaitingResponse<S, H> {
#[inline]
fn init(info: &mut PipelineInfo<S>, reply: Reply) -> PipelineState<S, H> {
match reply.into() {
ReplyItem::Message(resp) =>
RunMiddlewares::init(info, resp),
ReplyItem::Future(fut) =>
PipelineState::Handler(
WaitingResponse { fut, _s: PhantomData, _h: PhantomData }),
ReplyItem::Message(resp) => RunMiddlewares::init(info, resp),
ReplyItem::Future(fut) => PipelineState::Handler(WaitingResponse {
fut,
_s: PhantomData,
_h: PhantomData,
}),
}
}
fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> {
match self.fut.poll() {
Ok(Async::NotReady) => None,
Ok(Async::Ready(response)) =>
Some(RunMiddlewares::init(info, response)),
Err(err) =>
Some(ProcessResponse::init(err.into())),
Ok(Async::Ready(response)) => Some(RunMiddlewares::init(info, response)),
Err(err) => Some(ProcessResponse::init(err.into())),
}
}
}
@ -336,13 +336,12 @@ impl<S: 'static, H> WaitingResponse<S, H> {
/// Middlewares response executor
struct RunMiddlewares<S, H> {
curr: usize,
fut: Option<Box<Future<Item=HttpResponse, Error=Error>>>,
fut: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
_s: PhantomData<S>,
_h: PhantomData<H>,
}
impl<S: 'static, H> RunMiddlewares<S, H> {
fn init(info: &mut PipelineInfo<S>, mut resp: HttpResponse) -> PipelineState<S, H> {
if info.count == 0 {
return ProcessResponse::init(resp);
@ -354,21 +353,24 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
resp = match info.mws[curr].response(info.req_mut(), resp) {
Err(err) => {
info.count = (curr + 1) as u16;
return ProcessResponse::init(err.into())
return ProcessResponse::init(err.into());
}
Ok(Response::Done(r)) => {
curr += 1;
if curr == len {
return ProcessResponse::init(r)
return ProcessResponse::init(r);
} else {
r
}
},
}
Ok(Response::Future(fut)) => {
return PipelineState::RunMiddlewares(
RunMiddlewares { curr, fut: Some(fut),
_s: PhantomData, _h: PhantomData })
},
return PipelineState::RunMiddlewares(RunMiddlewares {
curr,
fut: Some(fut),
_s: PhantomData,
_h: PhantomData,
})
}
};
}
}
@ -379,15 +381,12 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
loop {
// poll latest fut
let mut resp = match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => {
return None
}
Ok(Async::NotReady) => return None,
Ok(Async::Ready(resp)) => {
self.curr += 1;
resp
}
Err(err) =>
return Some(ProcessResponse::init(err.into())),
Err(err) => return Some(ProcessResponse::init(err.into())),
};
loop {
@ -395,16 +394,15 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
return Some(ProcessResponse::init(resp));
} else {
match info.mws[self.curr].response(info.req_mut(), resp) {
Err(err) =>
return Some(ProcessResponse::init(err.into())),
Err(err) => return Some(ProcessResponse::init(err.into())),
Ok(Response::Done(r)) => {
self.curr += 1;
resp = r
},
}
Ok(Response::Future(fut)) => {
self.fut = Some(fut);
break
},
break;
}
}
}
}
@ -451,42 +449,56 @@ enum IOState {
}
impl<S: 'static, H> ProcessResponse<S, H> {
#[inline]
fn init(resp: HttpResponse) -> PipelineState<S, H> {
PipelineState::Response(
ProcessResponse{ resp,
iostate: IOState::Response,
running: RunningState::Running,
drain: None, _s: PhantomData, _h: PhantomData})
PipelineState::Response(ProcessResponse {
resp,
iostate: IOState::Response,
running: RunningState::Running,
drain: None,
_s: PhantomData,
_h: PhantomData,
})
}
fn poll_io(mut self, io: &mut Writer, info: &mut PipelineInfo<S>)
-> Result<PipelineState<S, H>, PipelineState<S, H>>
{
fn poll_io(
mut self, io: &mut Writer, info: &mut PipelineInfo<S>
) -> Result<PipelineState<S, H>, PipelineState<S, H>> {
loop {
if self.drain.is_none() && self.running != RunningState::Paused {
// if task is paused, write buffer is probably full
'inner: loop {
let result = match mem::replace(&mut self.iostate, IOState::Done) {
IOState::Response => {
let encoding = self.resp.content_encoding().unwrap_or(info.encoding);
let encoding =
self.resp.content_encoding().unwrap_or(info.encoding);
let result = match io.start(info.req_mut().get_inner(),
&mut self.resp, encoding)
{
let result = match io.start(
info.req_mut().get_inner(),
&mut self.resp,
encoding,
) {
Ok(res) => res,
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
}
};
if let Some(err) = self.resp.error() {
if self.resp.status().is_server_error() {
error!("Error occured during request handling: {}", err);
error!(
"Error occured during request handling: {}",
err
);
} else {
warn!("Error occured during request handling: {}", err);
warn!(
"Error occured during request handling: {}",
err
);
}
if log_enabled!(Debug) {
debug!("{:?}", err);
@ -497,44 +509,48 @@ impl<S: 'static, H> ProcessResponse<S, H> {
match self.resp.replace_body(Body::Empty) {
Body::Streaming(stream) => {
self.iostate = IOState::Payload(stream);
continue 'inner
},
Body::Actor(ctx) => {
self.iostate = IOState::Actor(ctx);
continue 'inner
},
continue 'inner;
}
Body::Actor(ctx) => {
self.iostate = IOState::Actor(ctx);
continue 'inner;
}
_ => (),
}
result
},
IOState::Payload(mut body) => {
match body.poll() {
Ok(Async::Ready(None)) => {
if let Err(err) = io.write_eof() {
}
IOState::Payload(mut body) => match body.poll() {
Ok(Async::Ready(None)) => {
if let Err(err) = io.write_eof() {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
}
break;
}
Ok(Async::Ready(Some(chunk))) => {
self.iostate = IOState::Payload(body);
match io.write(chunk.into()) {
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
}
break
},
Ok(Async::Ready(Some(chunk))) => {
self.iostate = IOState::Payload(body);
match io.write(chunk.into()) {
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
},
Ok(result) => result
return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
}
Ok(result) => result,
}
Ok(Async::NotReady) => {
self.iostate = IOState::Payload(body);
break
},
Err(err) => {
info.error = Some(err);
return Ok(FinishingMiddlewares::init(info, self.resp))
}
}
Ok(Async::NotReady) => {
self.iostate = IOState::Payload(body);
break;
}
Err(err) => {
info.error = Some(err);
return Ok(FinishingMiddlewares::init(info, self.resp));
}
},
IOState::Actor(mut ctx) => {
@ -545,7 +561,7 @@ impl<S: 'static, H> ProcessResponse<S, H> {
Ok(Async::Ready(Some(vec))) => {
if vec.is_empty() {
self.iostate = IOState::Actor(ctx);
break
break;
}
let mut res = None;
for frame in vec {
@ -555,40 +571,49 @@ impl<S: 'static, H> ProcessResponse<S, H> {
if let Err(err) = io.write_eof() {
info.error = Some(err.into());
return Ok(
FinishingMiddlewares::init(info, self.resp))
FinishingMiddlewares::init(
info,
self.resp,
),
);
}
break 'inner
},
break 'inner;
}
Frame::Chunk(Some(chunk)) => {
match io.write(chunk) {
Err(err) => {
info.error = Some(err.into());
return Ok(
FinishingMiddlewares::init(info, self.resp))
},
FinishingMiddlewares::init(
info,
self.resp,
),
);
}
Ok(result) => res = Some(result),
}
},
}
Frame::Drain(fut) => self.drain = Some(fut),
}
}
self.iostate = IOState::Actor(ctx);
if self.drain.is_some() {
self.running.resume();
break 'inner
break 'inner;
}
res.unwrap()
},
Ok(Async::Ready(None)) => {
break
}
Ok(Async::Ready(None)) => break,
Ok(Async::NotReady) => {
self.iostate = IOState::Actor(ctx);
break
break;
}
Err(err) => {
info.error = Some(err);
return Ok(FinishingMiddlewares::init(info, self.resp))
return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
}
}
}
@ -598,11 +623,9 @@ impl<S: 'static, H> ProcessResponse<S, H> {
match result {
WriterState::Pause => {
self.running.pause();
break
break;
}
WriterState::Done => {
self.running.resume()
},
WriterState::Done => self.running.resume(),
}
}
}
@ -618,17 +641,16 @@ impl<S: 'static, H> ProcessResponse<S, H> {
let _ = tx.send(());
}
// restart io processing
continue
},
Ok(Async::NotReady) =>
return Err(PipelineState::Response(self)),
continue;
}
Ok(Async::NotReady) => return Err(PipelineState::Response(self)),
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
return Ok(FinishingMiddlewares::init(info, self.resp));
}
}
}
break
break;
}
// response is completed
@ -638,7 +660,7 @@ impl<S: 'static, H> ProcessResponse<S, H> {
Ok(_) => (),
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
return Ok(FinishingMiddlewares::init(info, self.resp));
}
}
self.resp.set_response_size(io.written());
@ -652,19 +674,22 @@ impl<S: 'static, H> ProcessResponse<S, H> {
/// Middlewares start executor
struct FinishingMiddlewares<S, H> {
resp: HttpResponse,
fut: Option<Box<Future<Item=(), Error=Error>>>,
fut: Option<Box<Future<Item = (), Error = Error>>>,
_s: PhantomData<S>,
_h: PhantomData<H>,
}
impl<S: 'static, H> FinishingMiddlewares<S, H> {
fn init(info: &mut PipelineInfo<S>, resp: HttpResponse) -> PipelineState<S, H> {
if info.count == 0 {
Completed::init(info)
} else {
let mut state = FinishingMiddlewares{resp, fut: None,
_s: PhantomData, _h: PhantomData};
let mut state = FinishingMiddlewares {
resp,
fut: None,
_s: PhantomData,
_h: PhantomData,
};
if let Some(st) = state.poll(info) {
st
} else {
@ -678,12 +703,8 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
// poll latest fut
let not_ready = if let Some(ref mut fut) = self.fut {
match fut.poll() {
Ok(Async::NotReady) => {
true
},
Ok(Async::Ready(())) => {
false
},
Ok(Async::NotReady) => true,
Ok(Async::Ready(())) => false,
Err(err) => {
error!("Middleware finish error: {}", err);
false
@ -701,12 +722,12 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) {
Finished::Done => {
if info.count == 0 {
return Some(Completed::init(info))
return Some(Completed::init(info));
}
}
Finished::Future(fut) => {
self.fut = Some(fut);
},
}
}
}
}
@ -716,7 +737,6 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
struct Completed<S, H>(PhantomData<S>, PhantomData<H>);
impl<S, H> Completed<S, H> {
#[inline]
fn init(info: &mut PipelineInfo<S>) -> PipelineState<S, H> {
if let Some(ref err) = info.error {
@ -745,15 +765,23 @@ mod tests {
use super::*;
use actix::*;
use context::HttpContext;
use tokio_core::reactor::Core;
use futures::future::{lazy, result};
use tokio_core::reactor::Core;
impl<S, H> PipelineState<S, H> {
fn is_none(&self) -> Option<bool> {
if let PipelineState::None = *self { Some(true) } else { None }
if let PipelineState::None = *self {
Some(true)
} else {
None
}
}
fn completed(self) -> Option<Completed<S, H>> {
if let PipelineState::Completed(c) = self { Some(c) } else { None }
if let PipelineState::Completed(c) = self {
Some(c)
} else {
None
}
}
}
@ -764,28 +792,35 @@ mod tests {
#[test]
fn test_completed() {
Core::new().unwrap().run(lazy(|| {
let mut info = PipelineInfo::new(HttpRequest::default());
Completed::<(), Inner<()>>::init(&mut info).is_none().unwrap();
Core::new()
.unwrap()
.run(lazy(|| {
let mut info = PipelineInfo::new(HttpRequest::default());
Completed::<(), Inner<()>>::init(&mut info)
.is_none()
.unwrap();
let req = HttpRequest::default();
let mut ctx = HttpContext::new(req.clone(), MyActor);
let addr: Addr<Unsync, _> = ctx.address();
let mut info = PipelineInfo::new(req);
info.context = Some(Box::new(ctx));
let mut state = Completed::<(), Inner<()>>::init(&mut info).completed().unwrap();
let req = HttpRequest::default();
let mut ctx = HttpContext::new(req.clone(), MyActor);
let addr: Addr<Unsync, _> = ctx.address();
let mut info = PipelineInfo::new(req);
info.context = Some(Box::new(ctx));
let mut state = Completed::<(), Inner<()>>::init(&mut info)
.completed()
.unwrap();
assert!(state.poll(&mut info).is_none());
let pp = Pipeline(info, PipelineState::Completed(state));
assert!(!pp.is_done());
assert!(state.poll(&mut info).is_none());
let pp = Pipeline(info, PipelineState::Completed(state));
assert!(!pp.is_done());
let Pipeline(mut info, st) = pp;
let mut st = st.completed().unwrap();
drop(addr);
let Pipeline(mut info, st) = pp;
let mut st = st.completed().unwrap();
drop(addr);
assert!(st.poll(&mut info).unwrap().is_none().unwrap());
assert!(st.poll(&mut info).unwrap().is_none().unwrap());
result(Ok::<_, ()>(()))
})).unwrap();
result(Ok::<_, ()>(()))
}))
.unwrap();
}
}

View File

@ -1,20 +1,18 @@
//! Route match predicates
#![allow(non_snake_case)]
use std::marker::PhantomData;
use http;
use http::{header, HttpTryFrom};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use std::marker::PhantomData;
/// Trait defines resource route predicate.
/// Predicate can modify request object. It is also possible to
/// to store extra attributes on request by using `Extensions` container,
/// Extensions container available via `HttpRequest::extensions()` method.
pub trait Predicate<S> {
/// Check if request matches predicate
fn check(&self, &mut HttpRequest<S>) -> bool;
}
/// Return predicate that matches if any of supplied predicate matches.
@ -30,8 +28,7 @@ pub trait Predicate<S> {
/// .f(|r| HttpResponse::MethodNotAllowed()));
/// }
/// ```
pub fn Any<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AnyPredicate<S>
{
pub fn Any<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AnyPredicate<S> {
AnyPredicate(vec![Box::new(pred)])
}
@ -50,7 +47,7 @@ impl<S: 'static> Predicate<S> for AnyPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool {
for p in &self.0 {
if p.check(req) {
return true
return true;
}
}
false
@ -90,7 +87,7 @@ impl<S: 'static> Predicate<S> for AllPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool {
for p in &self.0 {
if !p.check(req) {
return false
return false;
}
}
true
@ -98,8 +95,7 @@ impl<S: 'static> Predicate<S> for AllPredicate<S> {
}
/// Return predicate that matches if supplied predicate does not match.
pub fn Not<S: 'static, P: Predicate<S> + 'static>(pred: P) -> NotPredicate<S>
{
pub fn Not<S: 'static, P: Predicate<S> + 'static>(pred: P) -> NotPredicate<S> {
NotPredicate(Box::new(pred))
}
@ -172,21 +168,29 @@ pub fn Method<S: 'static>(method: http::Method) -> MethodPredicate<S> {
MethodPredicate(method, PhantomData)
}
/// Return predicate that matches if request contains specified header and value.
pub fn Header<S: 'static>(name: &'static str, value: &'static str) -> HeaderPredicate<S>
{
HeaderPredicate(header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value),
PhantomData)
/// Return predicate that matches if request contains specified header and
/// value.
pub fn Header<S: 'static>(
name: &'static str, value: &'static str
) -> HeaderPredicate<S> {
HeaderPredicate(
header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value),
PhantomData,
)
}
#[doc(hidden)]
pub struct HeaderPredicate<S>(header::HeaderName, header::HeaderValue, PhantomData<S>);
pub struct HeaderPredicate<S>(
header::HeaderName,
header::HeaderValue,
PhantomData<S>,
);
impl<S: 'static> Predicate<S> for HeaderPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool {
if let Some(val) = req.headers().get(&self.0) {
return val == self.1
return val == self.1;
}
false
}
@ -195,17 +199,24 @@ impl<S: 'static> Predicate<S> for HeaderPredicate<S> {
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
use http::{Uri, Version, Method};
use http::header::{self, HeaderMap};
use http::{Method, Uri, Version};
use std::str::FromStr;
#[test]
fn test_header() {
let mut headers = HeaderMap::new();
headers.insert(header::TRANSFER_ENCODING,
header::HeaderValue::from_static("chunked"));
headers.insert(
header::TRANSFER_ENCODING,
header::HeaderValue::from_static("chunked"),
);
let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None);
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let pred = Header("transfer-encoding", "chunked");
assert!(pred.check(&mut req));
@ -220,11 +231,19 @@ mod tests {
#[test]
fn test_methods() {
let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
let mut req2 = HttpRequest::new(
Method::POST, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::POST,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Get().check(&mut req));
assert!(!Get().check(&mut req2));
@ -232,44 +251,72 @@ mod tests {
assert!(!Post().check(&mut req));
let mut r = HttpRequest::new(
Method::PUT, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::PUT,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Put().check(&mut r));
assert!(!Put().check(&mut req));
let mut r = HttpRequest::new(
Method::DELETE, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::DELETE,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Delete().check(&mut r));
assert!(!Delete().check(&mut req));
let mut r = HttpRequest::new(
Method::HEAD, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::HEAD,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Head().check(&mut r));
assert!(!Head().check(&mut req));
let mut r = HttpRequest::new(
Method::OPTIONS, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::OPTIONS,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Options().check(&mut r));
assert!(!Options().check(&mut req));
let mut r = HttpRequest::new(
Method::CONNECT, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::CONNECT,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Connect().check(&mut r));
assert!(!Connect().check(&mut req));
let mut r = HttpRequest::new(
Method::PATCH, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::PATCH,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Patch().check(&mut r));
assert!(!Patch().check(&mut req));
let mut r = HttpRequest::new(
Method::TRACE, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::TRACE,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Trace().check(&mut r));
assert!(!Trace().check(&mut req));
}
@ -277,8 +324,12 @@ mod tests {
#[test]
fn test_preds() {
let mut r = HttpRequest::new(
Method::TRACE, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
Method::TRACE,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Not(Get()).check(&mut r));
assert!(!Not(Trace()).check(&mut r));

View File

@ -1,15 +1,15 @@
use std::rc::Rc;
use std::marker::PhantomData;
use std::rc::Rc;
use smallvec::SmallVec;
use http::{Method, StatusCode};
use smallvec::SmallVec;
use pred;
use route::Route;
use handler::{Reply, Handler, Responder, FromRequest};
use middleware::Middleware;
use handler::{FromRequest, Handler, Reply, Responder};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::Middleware;
use pred;
use route::Route;
/// *Resource* is an entry in route table which corresponds to requested URL.
///
@ -18,8 +18,8 @@ use httpresponse::HttpResponse;
/// and list of predicates (objects that implement `Predicate` trait).
/// Route uses builder-like pattern for configuration.
/// During request handling, resource object iterate through all routes
/// and check all predicates for specific route, if request matches all predicates route
/// route considered matched and route handler get called.
/// and check all predicates for specific route, if request matches all
/// predicates route route considered matched and route handler get called.
///
/// ```rust
/// # extern crate actix_web;
@ -31,7 +31,7 @@ use httpresponse::HttpResponse;
/// "/", |r| r.method(http::Method::GET).f(|r| HttpResponse::Ok()))
/// .finish();
/// }
pub struct ResourceHandler<S=()> {
pub struct ResourceHandler<S = ()> {
name: String,
state: PhantomData<S>,
routes: SmallVec<[Route<S>; 3]>,
@ -44,18 +44,19 @@ impl<S> Default for ResourceHandler<S> {
name: String::new(),
state: PhantomData,
routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()) }
middlewares: Rc::new(Vec::new()),
}
}
}
impl<S> ResourceHandler<S> {
pub(crate) fn default_not_found() -> Self {
ResourceHandler {
name: String::new(),
state: PhantomData,
routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()) }
middlewares: Rc::new(Vec::new()),
}
}
/// Set resource name
@ -69,9 +70,9 @@ impl<S> ResourceHandler<S> {
}
impl<S: 'static> ResourceHandler<S> {
/// Register a new route and return mutable reference to *Route* object.
/// *Route* is used for route configuration, i.e. adding predicates, setting up handler.
/// *Route* is used for route configuration, i.e. adding predicates,
/// setting up handler.
///
/// ```rust
/// # extern crate actix_web;
@ -131,7 +132,10 @@ impl<S: 'static> ResourceHandler<S> {
/// ```
pub fn method(&mut self, method: Method) -> &mut Route<S> {
self.routes.push(Route::default());
self.routes.last_mut().unwrap().filter(pred::Method(method))
self.routes
.last_mut()
.unwrap()
.filter(pred::Method(method))
}
/// Register a new route and add handler object.
@ -154,8 +158,9 @@ impl<S: 'static> ResourceHandler<S> {
/// Application::resource("/", |r| r.route().f(index)
/// ```
pub fn f<F, R>(&mut self, handler: F)
where F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
where
F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{
self.routes.push(Route::default());
self.routes.last_mut().unwrap().f(handler)
@ -169,9 +174,10 @@ impl<S: 'static> ResourceHandler<S> {
/// Application::resource("/", |r| r.route().with(index)
/// ```
pub fn with<T, F, R>(&mut self, handler: F)
where F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
where
F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
{
self.routes.push(Route::default());
self.routes.last_mut().unwrap().with(handler);
@ -182,13 +188,14 @@ impl<S: 'static> ResourceHandler<S> {
/// This is similar to `App's` middlewares, but
/// middlewares get invoked on resource level.
pub fn middleware<M: Middleware<S>>(&mut self, mw: M) {
Rc::get_mut(&mut self.middlewares).unwrap().push(Box::new(mw));
Rc::get_mut(&mut self.middlewares)
.unwrap()
.push(Box::new(mw));
}
pub(crate) fn handle(&mut self,
mut req: HttpRequest<S>,
default: Option<&mut ResourceHandler<S>>) -> Reply
{
pub(crate) fn handle(
&mut self, mut req: HttpRequest<S>, default: Option<&mut ResourceHandler<S>>
) -> Reply {
for route in &mut self.routes {
if route.check(&mut req) {
return if self.middlewares.is_empty() {

View File

@ -1,17 +1,18 @@
use futures::{Async, Future, Poll};
use std::marker::PhantomData;
use std::mem;
use std::rc::Rc;
use std::marker::PhantomData;
use futures::{Async, Future, Poll};
use error::Error;
use pred::Predicate;
use handler::{AsyncHandler, FromRequest, Handler, Reply, ReplyItem, Responder,
RouteHandler, WrapHandler};
use http::StatusCode;
use handler::{Reply, ReplyItem, Handler, FromRequest,
Responder, RouteHandler, AsyncHandler, WrapHandler};
use middleware::{Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use with::{With, With2, With3, ExtractorConfig};
use middleware::{Middleware, Response as MiddlewareResponse,
Started as MiddlewareStarted};
use pred::Predicate;
use with::{ExtractorConfig, With, With2, With3};
/// Resource route definition
///
@ -23,7 +24,6 @@ pub struct Route<S> {
}
impl<S: 'static> Default for Route<S> {
fn default() -> Route<S> {
Route {
preds: Vec::new(),
@ -33,12 +33,11 @@ impl<S: 'static> Default for Route<S> {
}
impl<S: 'static> Route<S> {
#[inline]
pub(crate) fn check(&self, req: &mut HttpRequest<S>) -> bool {
for pred in &self.preds {
if !pred.check(req) {
return false
return false;
}
}
true
@ -50,9 +49,9 @@ impl<S: 'static> Route<S> {
}
#[inline]
pub(crate) fn compose(&mut self,
req: HttpRequest<S>,
mws: Rc<Vec<Box<Middleware<S>>>>) -> Reply {
pub(crate) fn compose(
&mut self, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>
) -> Reply {
Reply::async(Compose::new(req, mws, self.handler.clone()))
}
@ -86,18 +85,20 @@ impl<S: 'static> Route<S> {
/// Set handler function. Usually call to this method is last call
/// during route configuration, so it does not return reference to self.
pub fn f<F, R>(&mut self, handler: F)
where F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
where
F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{
self.handler = InnerHandler::new(handler);
}
/// Set async handler function.
pub fn a<H, R, F, E>(&mut self, handler: H)
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
{
self.handler = InnerHandler::async(handler);
}
@ -128,9 +129,10 @@ impl<S: 'static> Route<S> {
/// }
/// ```
pub fn with<T, F, R>(&mut self, handler: F) -> ExtractorConfig<S, T>
where F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
where
F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
{
let cfg = ExtractorConfig::default();
self.h(With::new(handler, Clone::clone(&cfg)));
@ -167,43 +169,58 @@ impl<S: 'static> Route<S> {
/// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor
/// }
/// ```
pub fn with2<T1, T2, F, R>(&mut self, handler: F)
-> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>)
where F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
pub fn with2<T1, T2, F, R>(
&mut self, handler: F
) -> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>)
where
F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
{
let cfg1 = ExtractorConfig::default();
let cfg2 = ExtractorConfig::default();
self.h(With2::new(handler, Clone::clone(&cfg1), Clone::clone(&cfg2)));
self.h(With2::new(
handler,
Clone::clone(&cfg1),
Clone::clone(&cfg2),
));
(cfg1, cfg2)
}
/// Set handler function, use request extractor for all paramters.
pub fn with3<T1, T2, T3, F, R>(&mut self, handler: F)
-> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>, ExtractorConfig<S, T3>)
where F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
pub fn with3<T1, T2, T3, F, R>(
&mut self, handler: F
) -> (
ExtractorConfig<S, T1>,
ExtractorConfig<S, T2>,
ExtractorConfig<S, T3>,
)
where
F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
{
let cfg1 = ExtractorConfig::default();
let cfg2 = ExtractorConfig::default();
let cfg3 = ExtractorConfig::default();
self.h(With3::new(
handler, Clone::clone(&cfg1), Clone::clone(&cfg2), Clone::clone(&cfg3)));
handler,
Clone::clone(&cfg1),
Clone::clone(&cfg2),
Clone::clone(&cfg3),
));
(cfg1, cfg2, cfg3)
}
}
/// `RouteHandler` wrapper. This struct is required because it needs to be shared
/// for resource level middlewares.
/// `RouteHandler` wrapper. This struct is required because it needs to be
/// shared for resource level middlewares.
struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>);
impl<S: 'static> InnerHandler<S> {
#[inline]
fn new<H: Handler<S>>(h: H) -> Self {
InnerHandler(Rc::new(Box::new(WrapHandler::new(h))))
@ -211,10 +228,11 @@ impl<S: 'static> InnerHandler<S> {
#[inline]
fn async<H, R, F, E>(h: H) -> Self
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static,
E: Into<Error> + 'static,
{
InnerHandler(Rc::new(Box::new(AsyncHandler::new(h))))
}
@ -237,7 +255,6 @@ impl<S> Clone for InnerHandler<S> {
}
}
/// Compose resource level middlewares with route handler.
struct Compose<S: 'static> {
info: ComposeInfo<S>,
@ -270,14 +287,18 @@ impl<S: 'static> ComposeState<S> {
}
impl<S: 'static> Compose<S> {
fn new(req: HttpRequest<S>,
mws: Rc<Vec<Box<Middleware<S>>>>,
handler: InnerHandler<S>) -> Self
{
let mut info = ComposeInfo { count: 0, req, mws, handler };
fn new(
req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, handler: InnerHandler<S>
) -> Self {
let mut info = ComposeInfo {
count: 0,
req,
mws,
handler,
};
let state = StartMiddlewares::init(&mut info);
Compose {state, info}
Compose { state, info }
}
}
@ -289,12 +310,12 @@ impl<S> Future for Compose<S> {
loop {
if let ComposeState::Response(ref mut resp) = self.state {
let resp = resp.resp.take().unwrap();
return Ok(Async::Ready(resp))
return Ok(Async::Ready(resp));
}
if let Some(state) = self.state.poll(&mut self.info) {
self.state = state;
} else {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
}
@ -306,51 +327,47 @@ struct StartMiddlewares<S> {
_s: PhantomData<S>,
}
type Fut = Box<Future<Item=Option<HttpResponse>, Error=Error>>;
type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
impl<S: 'static> StartMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> {
let len = info.mws.len();
loop {
if info.count == len {
let reply = info.handler.handle(info.req.clone());
return WaitingResponse::init(info, reply)
return WaitingResponse::init(info, reply);
} else {
match info.mws[info.count].start(&mut info.req) {
Ok(MiddlewareStarted::Done) =>
info.count += 1,
Ok(MiddlewareStarted::Response(resp)) =>
return RunMiddlewares::init(info, resp),
Ok(MiddlewareStarted::Future(mut fut)) =>
match fut.poll() {
Ok(Async::NotReady) =>
return ComposeState::Starting(StartMiddlewares {
fut: Some(fut),
_s: PhantomData}),
Ok(Async::Ready(resp)) => {
if let Some(resp) = resp {
return RunMiddlewares::init(info, resp);
}
info.count += 1;
Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => {
return RunMiddlewares::init(info, resp)
}
Ok(MiddlewareStarted::Future(mut fut)) => match fut.poll() {
Ok(Async::NotReady) => {
return ComposeState::Starting(StartMiddlewares {
fut: Some(fut),
_s: PhantomData,
})
}
Ok(Async::Ready(resp)) => {
if let Some(resp) = resp {
return RunMiddlewares::init(info, resp);
}
Err(err) =>
return Response::init(err.into()),
},
Err(err) =>
return Response::init(err.into()),
info.count += 1;
}
Err(err) => return Response::init(err.into()),
},
Err(err) => return Response::init(err.into()),
}
}
}
}
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>>
{
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len();
'outer: loop {
match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) =>
return None,
Ok(Async::NotReady) => return None,
Ok(Async::Ready(resp)) => {
info.count += 1;
if let Some(resp) = resp {
@ -362,23 +379,20 @@ impl<S: 'static> StartMiddlewares<S> {
} else {
loop {
match info.mws[info.count].start(&mut info.req) {
Ok(MiddlewareStarted::Done) =>
info.count += 1,
Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp));
},
}
Ok(MiddlewareStarted::Future(fut)) => {
self.fut = Some(fut);
continue 'outer
},
Err(err) =>
return Some(Response::init(err.into()))
continue 'outer;
}
Err(err) => return Some(Response::init(err.into())),
}
}
}
}
Err(err) =>
return Some(Response::init(err.into()))
Err(err) => return Some(Response::init(err.into())),
}
}
}
@ -386,44 +400,39 @@ impl<S: 'static> StartMiddlewares<S> {
// waiting for response
struct WaitingResponse<S> {
fut: Box<Future<Item=HttpResponse, Error=Error>>,
fut: Box<Future<Item = HttpResponse, Error = Error>>,
_s: PhantomData<S>,
}
impl<S: 'static> WaitingResponse<S> {
#[inline]
fn init(info: &mut ComposeInfo<S>, reply: Reply) -> ComposeState<S> {
match reply.into() {
ReplyItem::Message(resp) =>
RunMiddlewares::init(info, resp),
ReplyItem::Future(fut) =>
ComposeState::Handler(
WaitingResponse { fut, _s: PhantomData }),
ReplyItem::Message(resp) => RunMiddlewares::init(info, resp),
ReplyItem::Future(fut) => ComposeState::Handler(WaitingResponse {
fut,
_s: PhantomData,
}),
}
}
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
match self.fut.poll() {
Ok(Async::NotReady) => None,
Ok(Async::Ready(response)) =>
Some(RunMiddlewares::init(info, response)),
Err(err) =>
Some(Response::init(err.into())),
Ok(Async::Ready(response)) => Some(RunMiddlewares::init(info, response)),
Err(err) => Some(Response::init(err.into())),
}
}
}
/// Middlewares response executor
struct RunMiddlewares<S> {
curr: usize,
fut: Option<Box<Future<Item=HttpResponse, Error=Error>>>,
fut: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
_s: PhantomData<S>,
}
impl<S: 'static> RunMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> {
let mut curr = 0;
let len = info.mws.len();
@ -432,40 +441,39 @@ impl<S: 'static> RunMiddlewares<S> {
resp = match info.mws[curr].response(&mut info.req, resp) {
Err(err) => {
info.count = curr + 1;
return Response::init(err.into())
},
return Response::init(err.into());
}
Ok(MiddlewareResponse::Done(r)) => {
curr += 1;
if curr == len {
return Response::init(r)
return Response::init(r);
} else {
r
}
},
}
Ok(MiddlewareResponse::Future(fut)) => {
return ComposeState::RunMiddlewares(
RunMiddlewares { curr, fut: Some(fut), _s: PhantomData })
},
return ComposeState::RunMiddlewares(RunMiddlewares {
curr,
fut: Some(fut),
_s: PhantomData,
})
}
};
}
}
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>>
{
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len();
loop {
// poll latest fut
let mut resp = match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => {
return None
}
Ok(Async::NotReady) => return None,
Ok(Async::Ready(resp)) => {
self.curr += 1;
resp
}
Err(err) =>
return Some(Response::init(err.into())),
Err(err) => return Some(Response::init(err.into())),
};
loop {
@ -473,16 +481,15 @@ impl<S: 'static> RunMiddlewares<S> {
return Some(Response::init(resp));
} else {
match info.mws[self.curr].response(&mut info.req, resp) {
Err(err) =>
return Some(Response::init(err.into())),
Err(err) => return Some(Response::init(err.into())),
Ok(MiddlewareResponse::Done(r)) => {
self.curr += 1;
resp = r
},
}
Ok(MiddlewareResponse::Future(fut)) => {
self.fut = Some(fut);
break
},
break;
}
}
}
}
@ -496,9 +503,10 @@ struct Response<S> {
}
impl<S: 'static> Response<S> {
fn init(resp: HttpResponse) -> ComposeState<S> {
ComposeState::Response(
Response{resp: Some(resp), _s: PhantomData})
ComposeState::Response(Response {
resp: Some(resp),
_s: PhantomData,
})
}
}

View File

@ -1,15 +1,15 @@
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem;
use std::rc::Rc;
use std::hash::{Hash, Hasher};
use std::collections::HashMap;
use regex::{Regex, escape};
use percent_encoding::percent_decode;
use regex::{escape, Regex};
use param::Params;
use error::UrlGenerationError;
use resource::ResourceHandler;
use httprequest::HttpRequest;
use param::Params;
use resource::ResourceHandler;
use server::ServerSettings;
/// Interface for application router.
@ -25,11 +25,10 @@ struct Inner {
impl Router {
/// Create new router
pub fn new<S>(prefix: &str,
settings: ServerSettings,
map: Vec<(Resource, Option<ResourceHandler<S>>)>)
-> (Router, Vec<ResourceHandler<S>>)
{
pub fn new<S>(
prefix: &str, settings: ServerSettings,
map: Vec<(Resource, Option<ResourceHandler<S>>)>,
) -> (Router, Vec<ResourceHandler<S>>) {
let prefix = prefix.trim().trim_right_matches('/').to_owned();
let mut named = HashMap::new();
let mut patterns = Vec::new();
@ -48,8 +47,16 @@ impl Router {
}
let prefix_len = prefix.len();
(Router(Rc::new(
Inner{ prefix, prefix_len, named, patterns, srv: settings })), resources)
(
Router(Rc::new(Inner {
prefix,
prefix_len,
named,
patterns,
srv: settings,
})),
resources,
)
}
/// Router prefix
@ -71,16 +78,18 @@ impl Router {
/// Query for matched resource
pub fn recognize<S>(&self, req: &mut HttpRequest<S>) -> Option<usize> {
if self.0.prefix_len > req.path().len() {
return None
return None;
}
let path: &str = unsafe{mem::transmute(&req.path()[self.0.prefix_len..])};
let path: &str = unsafe { mem::transmute(&req.path()[self.0.prefix_len..]) };
let route_path = if path.is_empty() { "/" } else { path };
let p = percent_decode(route_path.as_bytes()).decode_utf8().unwrap();
let p = percent_decode(route_path.as_bytes())
.decode_utf8()
.unwrap();
for (idx, pattern) in self.0.patterns.iter().enumerate() {
if pattern.match_with_params(p.as_ref(), req.match_info_mut()) {
req.set_resource(idx);
return Some(idx)
return Some(idx);
}
}
None
@ -97,7 +106,7 @@ impl Router {
for pattern in &self.0.patterns {
if pattern.is_match(path) {
return true
return true;
}
}
false
@ -105,12 +114,14 @@ impl Router {
/// Build named resource path.
///
/// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method.url_for)
/// for detailed information.
pub fn resource_path<U, I>(&self, name: &str, elements: U)
-> Result<String, UrlGenerationError>
where U: IntoIterator<Item=I>,
I: AsRef<str>,
/// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method.
/// url_for) for detailed information.
pub fn resource_path<U, I>(
&self, name: &str, elements: U
) -> Result<String, UrlGenerationError>
where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{
if let Some(pattern) = self.0.named.get(name) {
pattern.0.resource_path(self, elements)
@ -196,7 +207,7 @@ impl Resource {
let tp = if is_dynamic {
let re = match Regex::new(&pattern) {
Ok(re) => re,
Err(err) => panic!("Wrong path pattern: \"{}\" {}", path, err)
Err(err) => panic!("Wrong path pattern: \"{}\" {}", path, err),
};
let names = re.capture_names()
.filter_map(|name| name.map(|name| name.to_owned()))
@ -237,9 +248,9 @@ impl Resource {
}
}
pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>)
-> bool
{
pub fn match_with_params<'a>(
&'a self, path: &'a str, params: &'a mut Params<'a>
) -> bool {
match self.tp {
PatternType::Static(ref s) => s == path,
PatternType::Dynamic(ref re, ref names) => {
@ -248,7 +259,7 @@ impl Resource {
for capture in captures.iter() {
if let Some(ref m) = capture {
if idx != 0 {
params.add(names[idx-1].as_str(), m.as_str());
params.add(names[idx - 1].as_str(), m.as_str());
}
idx += 1;
}
@ -262,10 +273,12 @@ impl Resource {
}
/// Build reousrce path.
pub fn resource_path<U, I>(&self, router: &Router, elements: U)
-> Result<String, UrlGenerationError>
where U: IntoIterator<Item=I>,
I: AsRef<str>,
pub fn resource_path<U, I>(
&self, router: &Router, elements: U
) -> Result<String, UrlGenerationError>
where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{
let mut iter = elements.into_iter();
let mut path = if self.rtp != ResourceType::External {
@ -280,7 +293,7 @@ impl Resource {
if let Some(val) = iter.next() {
path.push_str(val.as_ref())
} else {
return Err(UrlGenerationError::NotEnoughElements)
return Err(UrlGenerationError::NotEnoughElements);
}
}
}
@ -374,20 +387,35 @@ mod tests {
#[test]
fn test_recognizer() {
let routes = vec![
(Resource::new("", "/name"),
Some(ResourceHandler::default())),
(Resource::new("", "/name/{val}"),
Some(ResourceHandler::default())),
(Resource::new("", "/name/{val}/index.html"),
Some(ResourceHandler::default())),
(Resource::new("", "/file/{file}.{ext}"),
Some(ResourceHandler::default())),
(Resource::new("", "/v{val}/{val2}/index.html"),
Some(ResourceHandler::default())),
(Resource::new("", "/v/{tail:.*}"),
Some(ResourceHandler::default())),
(Resource::new("", "{test}/index.html"),
Some(ResourceHandler::default()))];
(
Resource::new("", "/name"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}/index.html"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/file/{file}.{ext}"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/v{val}/{val2}/index.html"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/v/{tail:.*}"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "{test}/index.html"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish();
@ -415,7 +443,10 @@ mod tests {
let mut req = TestRequest::with_uri("/v/blah-blah/index.html").finish();
assert_eq!(rec.recognize(&mut req), Some(5));
assert_eq!(req.match_info().get("tail").unwrap(), "blah-blah/index.html");
assert_eq!(
req.match_info().get("tail").unwrap(),
"blah-blah/index.html"
);
let mut req = TestRequest::with_uri("/bbb/index.html").finish();
assert_eq!(rec.recognize(&mut req), Some(6));
@ -425,8 +456,15 @@ mod tests {
#[test]
fn test_recognizer_2() {
let routes = vec![
(Resource::new("", "/index.json"), Some(ResourceHandler::default())),
(Resource::new("", "/{source}.json"), Some(ResourceHandler::default()))];
(
Resource::new("", "/index.json"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/{source}.json"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/index.json").finish();
@ -439,8 +477,15 @@ mod tests {
#[test]
fn test_recognizer_with_prefix() {
let routes = vec![
(Resource::new("", "/name"), Some(ResourceHandler::default())),
(Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))];
(
Resource::new("", "/name"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("/test", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish();
@ -456,8 +501,15 @@ mod tests {
// same patterns
let routes = vec![
(Resource::new("", "/name"), Some(ResourceHandler::default())),
(Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))];
(
Resource::new("", "/name"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("/test2", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish();
@ -525,18 +577,25 @@ mod tests {
#[test]
fn test_request_resource() {
let routes = vec![
(Resource::new("r1", "/index.json"), Some(ResourceHandler::default())),
(Resource::new("r2", "/test.json"), Some(ResourceHandler::default()))];
(
Resource::new("r1", "/index.json"),
Some(ResourceHandler::default()),
),
(
Resource::new("r2", "/test.json"),
Some(ResourceHandler::default()),
),
];
let (router, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/index.json")
.finish_with_router(router.clone());
let mut req =
TestRequest::with_uri("/index.json").finish_with_router(router.clone());
assert_eq!(router.recognize(&mut req), Some(0));
let resource = req.resource();
assert_eq!(resource.name(), "r1");
let mut req = TestRequest::with_uri("/test.json")
.finish_with_router(router.clone());
let mut req =
TestRequest::with_uri("/test.json").finish_with_router(router.clone());
assert_eq!(router.recognize(&mut req), Some(1));
let resource = req.resource();
assert_eq!(resource.name(), "r2");

View File

@ -1,17 +1,16 @@
use std::{ptr, mem, time, io};
use std::net::{Shutdown, SocketAddr};
use std::rc::Rc;
use std::net::{SocketAddr, Shutdown};
use std::{io, mem, ptr, time};
use bytes::{Bytes, BytesMut, Buf, BufMut};
use futures::{Future, Poll, Async};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{Async, Future, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use super::{h1, h2, utils, HttpHandler, IoStream};
use super::settings::WorkerSettings;
use super::{utils, HttpHandler, IoStream, h1, h2};
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
enum HttpProtocol<T: IoStream, H: 'static> {
H1(h1::Http1<T, H>),
H2(h2::Http2<T, H>),
@ -24,27 +23,47 @@ enum ProtocolKind {
}
#[doc(hidden)]
pub struct HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static {
pub struct HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{
proto: Option<HttpProtocol<T, H>>,
node: Option<Node<HttpChannel<T, H>>>,
}
impl<T, H> HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static
impl<T, H> HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{
pub(crate) fn new(settings: Rc<WorkerSettings<H>>,
mut io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H>
{
pub(crate) fn new(
settings: Rc<WorkerSettings<H>>, mut io: T, peer: Option<SocketAddr>,
http2: bool,
) -> HttpChannel<T, H> {
settings.add_channel();
let _ = io.set_nodelay(true);
if http2 {
HttpChannel {
node: None, proto: Some(HttpProtocol::H2(
h2::Http2::new(settings, io, peer, Bytes::new()))) }
node: None,
proto: Some(HttpProtocol::H2(h2::Http2::new(
settings,
io,
peer,
Bytes::new(),
))),
}
} else {
HttpChannel {
node: None, proto: Some(HttpProtocol::Unknown(
settings, peer, io, BytesMut::with_capacity(4096))) }
node: None,
proto: Some(HttpProtocol::Unknown(
settings,
peer,
io,
BytesMut::with_capacity(4096),
)),
}
}
}
@ -55,15 +74,16 @@ impl<T, H> HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static
let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0)));
let _ = IoStream::shutdown(io, Shutdown::Both);
}
Some(HttpProtocol::H2(ref mut h2)) => {
h2.shutdown()
}
Some(HttpProtocol::H2(ref mut h2)) => h2.shutdown(),
_ => (),
}
}
}
impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static
impl<T, H> Future for HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{
type Item = ();
type Error = ();
@ -73,12 +93,15 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
let el = self as *mut _;
self.node = Some(Node::new(el));
let _ = match self.proto {
Some(HttpProtocol::H1(ref mut h1)) =>
self.node.as_ref().map(|n| h1.settings().head().insert(n)),
Some(HttpProtocol::H2(ref mut h2)) =>
self.node.as_ref().map(|n| h2.settings().head().insert(n)),
Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) =>
self.node.as_ref().map(|n| settings.head().insert(n)),
Some(HttpProtocol::H1(ref mut h1)) => self.node
.as_ref()
.map(|n| h1.settings().head().insert(n)),
Some(HttpProtocol::H2(ref mut h2)) => self.node
.as_ref()
.map(|n| h2.settings().head().insert(n)),
Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => {
self.node.as_ref().map(|n| settings.head().insert(n))
}
None => unreachable!(),
};
}
@ -90,30 +113,35 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
Ok(Async::Ready(())) | Err(_) => {
h1.settings().remove_channel();
self.node.as_mut().map(|n| n.remove());
},
}
_ => (),
}
return result
},
return result;
}
Some(HttpProtocol::H2(ref mut h2)) => {
let result = h2.poll();
match result {
Ok(Async::Ready(())) | Err(_) => {
h2.settings().remove_channel();
self.node.as_mut().map(|n| n.remove());
},
}
_ => (),
}
return result
},
Some(HttpProtocol::Unknown(ref mut settings, _, ref mut io, ref mut buf)) => {
return result;
}
Some(HttpProtocol::Unknown(
ref mut settings,
_,
ref mut io,
ref mut buf,
)) => {
match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) | Err(_) => {
debug!("Ignored premature client disconnection");
settings.remove_channel();
self.node.as_mut().map(|n| n.remove());
return Err(())
},
return Err(());
}
_ => (),
}
@ -126,7 +154,7 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
} else {
return Ok(Async::NotReady);
}
},
}
None => unreachable!(),
};
@ -134,30 +162,36 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() {
match kind {
ProtocolKind::Http1 => {
self.proto = Some(
HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf)));
return self.poll()
},
self.proto = Some(HttpProtocol::H1(h1::Http1::new(
settings,
io,
addr,
buf,
)));
return self.poll();
}
ProtocolKind::Http2 => {
self.proto = Some(
HttpProtocol::H2(h2::Http2::new(settings, io, addr, buf.freeze())));
return self.poll()
},
self.proto = Some(HttpProtocol::H2(h2::Http2::new(
settings,
io,
addr,
buf.freeze(),
)));
return self.poll();
}
}
}
unreachable!()
}
}
pub(crate) struct Node<T>
{
pub(crate) struct Node<T> {
next: Option<*mut Node<()>>,
prev: Option<*mut Node<()>>,
element: *mut T,
}
impl<T> Node<T>
{
impl<T> Node<T> {
fn new(el: *mut T) -> Self {
Node {
next: None,
@ -194,9 +228,7 @@ impl<T> Node<T>
}
}
impl Node<()> {
pub(crate) fn head() -> Self {
Node {
next: None,
@ -205,7 +237,11 @@ impl Node<()> {
}
}
pub(crate) fn traverse<T, H>(&self) where T: IoStream, H: HttpHandler + 'static {
pub(crate) fn traverse<T, H>(&self)
where
T: IoStream,
H: HttpHandler + 'static,
{
let mut next = self.next.as_ref();
loop {
if let Some(n) = next {
@ -214,30 +250,39 @@ impl Node<()> {
next = n.next.as_ref();
if !n.element.is_null() {
let ch: &mut HttpChannel<T, H> = mem::transmute(
&mut *(n.element as *mut _));
let ch: &mut HttpChannel<T, H> =
mem::transmute(&mut *(n.element as *mut _));
ch.shutdown();
}
}
} else {
return
return;
}
}
}
}
/// Wrapper for `AsyncRead + AsyncWrite` types
pub(crate) struct WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
io: T,
pub(crate) struct WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
io: T,
}
impl<T> WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
pub fn new(io: T) -> Self {
WrapperStream{ io }
WrapperStream { io }
}
}
impl<T> IoStream for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> IoStream for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline]
fn shutdown(&mut self, _: Shutdown) -> io::Result<()> {
Ok(())
@ -252,14 +297,20 @@ impl<T> IoStream for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static
}
}
impl<T> io::Read for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> io::Read for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.io.read(buf)
}
}
impl<T> io::Write for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> io::Write for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.io.write(buf)
@ -270,14 +321,20 @@ impl<T> io::Write for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static
}
}
impl<T> AsyncRead for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> AsyncRead for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline]
fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.io.read_buf(buf)
}
}
impl<T> AsyncWrite for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static {
impl<T> AsyncWrite for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline]
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.io.shutdown()

View File

@ -1,25 +1,24 @@
use std::{io, cmp, mem};
use std::io::{Read, Write};
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::str::FromStr;
use std::{cmp, io, mem};
use bytes::{Bytes, BytesMut, BufMut};
use http::{Version, Method, HttpTryFrom};
use http::header::{HeaderMap, HeaderValue,
ACCEPT_ENCODING, CONNECTION,
CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
#[cfg(feature = "brotli")]
use brotli2::write::{BrotliDecoder, BrotliEncoder};
use bytes::{BufMut, Bytes, BytesMut};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::{GzEncoder, DeflateDecoder, DeflateEncoder};
#[cfg(feature="brotli")]
use brotli2::write::{BrotliDecoder, BrotliEncoder};
use flate2::write::{DeflateDecoder, DeflateEncoder, GzEncoder};
use http::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING, CONNECTION,
CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{HttpTryFrom, Method, Version};
use header::ContentEncoding;
use body::{Body, Binary};
use body::{Binary, Body};
use error::PayloadError;
use header::ContentEncoding;
use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse;
use payload::{PayloadSender, PayloadWriter, PayloadStatus};
use payload::{PayloadSender, PayloadStatus, PayloadWriter};
use super::shared::SharedBytes;
@ -29,7 +28,6 @@ pub(crate) enum PayloadType {
}
impl PayloadType {
pub fn new(headers: &HeaderMap, sender: PayloadSender) -> PayloadType {
// check content-encoding
let enc = if let Some(enc) = headers.get(CONTENT_ENCODING) {
@ -43,8 +41,9 @@ impl PayloadType {
};
match enc {
ContentEncoding::Auto | ContentEncoding::Identity =>
PayloadType::Sender(sender),
ContentEncoding::Auto | ContentEncoding::Identity => {
PayloadType::Sender(sender)
}
_ => PayloadType::Encoding(Box::new(EncodedPayload::new(sender, enc))),
}
}
@ -84,7 +83,6 @@ impl PayloadWriter for PayloadType {
}
}
/// Payload wrapper with content decompression support
pub(crate) struct EncodedPayload {
inner: PayloadSender,
@ -94,12 +92,15 @@ pub(crate) struct EncodedPayload {
impl EncodedPayload {
pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload {
EncodedPayload{ inner, error: false, payload: PayloadStream::new(enc) }
EncodedPayload {
inner,
error: false,
payload: PayloadStream::new(enc),
}
}
}
impl PayloadWriter for EncodedPayload {
fn set_error(&mut self, err: PayloadError) {
self.inner.set_error(err)
}
@ -110,7 +111,7 @@ impl PayloadWriter for EncodedPayload {
Err(err) => {
self.error = true;
self.set_error(PayloadError::Io(err));
},
}
Ok(value) => {
if let Some(b) = value {
self.inner.feed_data(b);
@ -123,7 +124,7 @@ impl PayloadWriter for EncodedPayload {
fn feed_data(&mut self, data: Bytes) {
if self.error {
return
return;
}
match self.payload.feed_data(data) {
@ -145,7 +146,7 @@ impl PayloadWriter for EncodedPayload {
pub(crate) enum Decoder {
Deflate(Box<DeflateDecoder<Writer>>),
Gzip(Option<Box<GzDecoder<Wrapper>>>),
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
Br(Box<BrotliDecoder<Writer>>),
Identity,
}
@ -190,7 +191,9 @@ pub(crate) struct Writer {
impl Writer {
fn new() -> Writer {
Writer{buf: BytesMut::with_capacity(8192)}
Writer {
buf: BytesMut::with_capacity(8192),
}
}
fn take(&mut self) -> Bytes {
self.buf.take().freeze()
@ -216,65 +219,64 @@ pub(crate) struct PayloadStream {
impl PayloadStream {
pub fn new(enc: ContentEncoding) -> PayloadStream {
let dec = match enc {
#[cfg(feature="brotli")]
ContentEncoding::Br => Decoder::Br(
Box::new(BrotliDecoder::new(Writer::new()))),
ContentEncoding::Deflate => Decoder::Deflate(
Box::new(DeflateDecoder::new(Writer::new()))),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
Decoder::Br(Box::new(BrotliDecoder::new(Writer::new())))
}
ContentEncoding::Deflate => {
Decoder::Deflate(Box::new(DeflateDecoder::new(Writer::new())))
}
ContentEncoding::Gzip => Decoder::Gzip(None),
_ => Decoder::Identity,
};
PayloadStream{ decoder: dec, dst: BytesMut::new() }
PayloadStream {
decoder: dec,
dst: BytesMut::new(),
}
}
}
impl PayloadStream {
pub fn feed_eof(&mut self) -> io::Result<Option<Bytes>> {
match self.decoder {
#[cfg(feature="brotli")]
Decoder::Br(ref mut decoder) => {
match decoder.finish() {
Ok(mut writer) => {
let b = writer.take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
},
Err(e) => Err(e),
#[cfg(feature = "brotli")]
Decoder::Br(ref mut decoder) => match decoder.finish() {
Ok(mut writer) => {
let b = writer.take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
}
Err(e) => Err(e),
},
Decoder::Gzip(ref mut decoder) => {
if let Some(ref mut decoder) = *decoder {
decoder.as_mut().get_mut().eof = true;
self.dst.reserve(8192);
match decoder.read(unsafe{self.dst.bytes_mut()}) {
Ok(n) => {
unsafe{self.dst.advance_mut(n)};
return Ok(Some(self.dst.take().freeze()))
match decoder.read(unsafe { self.dst.bytes_mut() }) {
Ok(n) => {
unsafe { self.dst.advance_mut(n) };
return Ok(Some(self.dst.take().freeze()));
}
Err(e) =>
return Err(e),
Err(e) => return Err(e),
}
} else {
Ok(None)
}
},
Decoder::Deflate(ref mut decoder) => {
match decoder.try_finish() {
Ok(_) => {
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
},
Err(e) => Err(e),
}
Decoder::Deflate(ref mut decoder) => match decoder.try_finish() {
Ok(_) => {
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
}
Err(e) => Err(e),
},
Decoder::Identity => Ok(None),
}
@ -282,66 +284,67 @@ impl PayloadStream {
pub fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> {
match self.decoder {
#[cfg(feature="brotli")]
Decoder::Br(ref mut decoder) => {
match decoder.write_all(&data) {
Ok(_) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
},
Err(e) => Err(e)
#[cfg(feature = "brotli")]
Decoder::Br(ref mut decoder) => match decoder.write_all(&data) {
Ok(_) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
}
Err(e) => Err(e),
},
Decoder::Gzip(ref mut decoder) => {
if decoder.is_none() {
*decoder = Some(
Box::new(GzDecoder::new(
Wrapper{buf: BytesMut::from(data), eof: false})));
*decoder = Some(Box::new(GzDecoder::new(Wrapper {
buf: BytesMut::from(data),
eof: false,
})));
} else {
let _ = decoder.as_mut().unwrap().write(&data);
}
loop {
self.dst.reserve(8192);
match decoder.as_mut()
.as_mut().unwrap().read(unsafe{self.dst.bytes_mut()})
match decoder
.as_mut()
.as_mut()
.unwrap()
.read(unsafe { self.dst.bytes_mut() })
{
Ok(n) => {
Ok(n) => {
if n != 0 {
unsafe{self.dst.advance_mut(n)};
unsafe { self.dst.advance_mut(n) };
}
if n == 0 {
return Ok(Some(self.dst.take().freeze()));
}
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock && !self.dst.is_empty()
if e.kind() == io::ErrorKind::WouldBlock
&& !self.dst.is_empty()
{
return Ok(Some(self.dst.take().freeze()));
}
return Err(e)
return Err(e);
}
}
}
},
Decoder::Deflate(ref mut decoder) => {
match decoder.write_all(&data) {
Ok(_) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
},
Err(e) => Err(e),
}
Decoder::Deflate(ref mut decoder) => match decoder.write_all(&data) {
Ok(_) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if !b.is_empty() {
Ok(Some(b))
} else {
Ok(None)
}
}
Err(e) => Err(e),
},
Decoder::Identity => Ok(Some(data)),
}
@ -351,33 +354,33 @@ impl PayloadStream {
pub(crate) enum ContentEncoder {
Deflate(DeflateEncoder<TransferEncoding>),
Gzip(GzEncoder<TransferEncoding>),
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
Br(BrotliEncoder<TransferEncoding>),
Identity(TransferEncoding),
}
impl ContentEncoder {
pub fn empty(bytes: SharedBytes) -> ContentEncoder {
ContentEncoder::Identity(TransferEncoding::eof(bytes))
}
pub fn for_server(buf: SharedBytes,
req: &HttpInnerMessage,
resp: &mut HttpResponse,
response_encoding: ContentEncoding) -> ContentEncoder
{
pub fn for_server(
buf: SharedBytes, req: &HttpInnerMessage, resp: &mut HttpResponse,
response_encoding: ContentEncoding,
) -> ContentEncoder {
let version = resp.version().unwrap_or_else(|| req.version);
let is_head = req.method == Method::HEAD;
let mut body = resp.replace_body(Body::Empty);
let has_body = match body {
Body::Empty => false,
Body::Binary(ref bin) =>
!(response_encoding == ContentEncoding::Auto && bin.len() < 96),
Body::Binary(ref bin) => {
!(response_encoding == ContentEncoding::Auto && bin.len() < 96)
}
_ => true,
};
// Enable content encoding only if response does not contain Content-Encoding header
// Enable content encoding only if response does not contain Content-Encoding
// header
let mut encoding = if has_body {
let encoding = match response_encoding {
ContentEncoding::Auto => {
@ -396,7 +399,9 @@ impl ContentEncoder {
};
if encoding.is_compression() {
resp.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str()));
CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
}
encoding
} else {
@ -409,23 +414,27 @@ impl ContentEncoder {
resp.headers_mut().remove(CONTENT_LENGTH);
}
TransferEncoding::length(0, buf)
},
}
Body::Binary(ref mut bytes) => {
if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto)
|| encoding == ContentEncoding::Auto)
{
let tmp = SharedBytes::default();
let transfer = TransferEncoding::eof(tmp.clone());
let mut enc = match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::fast())),
ContentEncoding::Gzip => ContentEncoder::Gzip(
GzEncoder::new(transfer, Compression::fast())),
#[cfg(feature="brotli")]
ContentEncoding::Br => ContentEncoder::Br(
BrotliEncoder::new(transfer, 3)),
DeflateEncoder::new(transfer, Compression::fast()),
),
ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
transfer,
Compression::fast(),
)),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
ContentEncoder::Br(BrotliEncoder::new(transfer, 3))
}
ContentEncoding::Identity => ContentEncoder::Identity(transfer),
ContentEncoding::Auto => unreachable!()
ContentEncoding::Auto => unreachable!(),
};
// TODO return error!
let _ = enc.write(bytes.clone());
@ -438,7 +447,9 @@ impl ContentEncoder {
let mut b = BytesMut::new();
let _ = write!(b, "{}", bytes.len());
resp.headers_mut().insert(
CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
CONTENT_LENGTH,
HeaderValue::try_from(b.freeze()).unwrap(),
);
} else {
// resp.headers_mut().remove(CONTENT_LENGTH);
}
@ -449,8 +460,8 @@ impl ContentEncoder {
if version == Version::HTTP_2 {
error!("Connection upgrade is forbidden for HTTP/2");
} else {
resp.headers_mut().insert(
CONNECTION, HeaderValue::from_static("upgrade"));
resp.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
if encoding != ContentEncoding::Identity {
encoding = ContentEncoding::Identity;
@ -470,20 +481,24 @@ impl ContentEncoder {
}
match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::fast())),
ContentEncoding::Gzip => ContentEncoder::Gzip(
GzEncoder::new(transfer, Compression::fast())),
#[cfg(feature="brotli")]
ContentEncoding::Br => ContentEncoder::Br(
BrotliEncoder::new(transfer, 3)),
ContentEncoding::Identity | ContentEncoding::Auto =>
ContentEncoder::Identity(transfer),
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer,
Compression::fast(),
)),
ContentEncoding::Gzip => {
ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::fast()))
}
#[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
}
}
fn streaming_encoding(buf: SharedBytes, version: Version,
resp: &mut HttpResponse) -> TransferEncoding {
fn streaming_encoding(
buf: SharedBytes, version: Version, resp: &mut HttpResponse
) -> TransferEncoding {
match resp.chunked() {
Some(true) => {
// Enable transfer encoding
@ -492,13 +507,12 @@ impl ContentEncoder {
resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
} else {
resp.headers_mut().insert(
TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
resp.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf)
}
},
Some(false) =>
TransferEncoding::eof(buf),
}
Some(false) => TransferEncoding::eof(buf),
None => {
// if Content-Length is specified, then use it as length hint
let (len, chunked) =
@ -530,9 +544,11 @@ impl ContentEncoder {
match version {
Version::HTTP_11 => {
resp.headers_mut().insert(
TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TRANSFER_ENCODING,
HeaderValue::from_static("chunked"),
);
TransferEncoding::chunked(buf)
},
}
_ => {
resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
@ -545,11 +561,10 @@ impl ContentEncoder {
}
impl ContentEncoder {
#[inline]
pub fn is_eof(&self) -> bool {
match *self {
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(),
@ -561,39 +576,35 @@ impl ContentEncoder {
#[inline(always)]
pub fn write_eof(&mut self) -> Result<(), io::Error> {
let encoder = mem::replace(
self, ContentEncoder::Identity(TransferEncoding::eof(SharedBytes::empty())));
self,
ContentEncoder::Identity(TransferEncoding::eof(SharedBytes::empty())),
);
match encoder {
#[cfg(feature="brotli")]
ContentEncoder::Br(encoder) => {
match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
},
Err(err) => Err(err),
}
}
ContentEncoder::Gzip(encoder) => {
match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
},
Err(err) => Err(err),
#[cfg(feature = "brotli")]
ContentEncoder::Br(encoder) => match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
}
Err(err) => Err(err),
},
ContentEncoder::Deflate(encoder) => {
match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
},
Err(err) => Err(err),
ContentEncoder::Gzip(encoder) => match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
}
Err(err) => Err(err),
},
ContentEncoder::Deflate(encoder) => match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
}
Err(err) => Err(err),
},
ContentEncoder::Identity(mut writer) => {
writer.encode_eof();
@ -607,23 +618,23 @@ impl ContentEncoder {
#[inline(always)]
pub fn write(&mut self, data: Binary) -> Result<(), io::Error> {
match *self {
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => {
match encoder.write_all(data.as_ref()) {
Ok(_) => Ok(()),
Err(err) => {
trace!("Error decoding br encoding: {}", err);
Err(err)
},
}
}
},
}
ContentEncoder::Gzip(ref mut encoder) => {
match encoder.write_all(data.as_ref()) {
Ok(_) => Ok(()),
Err(err) => {
trace!("Error decoding gzip encoding: {}", err);
Err(err)
},
}
}
}
ContentEncoder::Deflate(ref mut encoder) => {
@ -632,7 +643,7 @@ impl ContentEncoder {
Err(err) => {
trace!("Error decoding deflate encoding: {}", err);
Err(err)
},
}
}
}
ContentEncoder::Identity(ref mut encoder) => {
@ -665,7 +676,6 @@ enum TransferEncodingKind {
}
impl TransferEncoding {
#[inline]
pub fn eof(bytes: SharedBytes) -> TransferEncoding {
TransferEncoding {
@ -707,7 +717,7 @@ impl TransferEncoding {
let eof = msg.is_empty();
self.buffer.extend(msg);
Ok(eof)
},
}
TransferEncodingKind::Chunked(ref mut eof) => {
if *eof {
return Ok(true);
@ -726,21 +736,22 @@ impl TransferEncoding {
self.buffer.extend_from_slice(b"\r\n");
}
Ok(*eof)
},
}
TransferEncodingKind::Length(ref mut remaining) => {
if *remaining > 0 {
if msg.is_empty() {
return Ok(*remaining == 0)
return Ok(*remaining == 0);
}
let len = cmp::min(*remaining, msg.len() as u64);
self.buffer.extend(msg.take().split_to(len as usize).into());
self.buffer
.extend(msg.take().split_to(len as usize).into());
*remaining -= len as u64;
Ok(*remaining == 0)
} else {
Ok(true)
}
},
}
}
}
@ -754,13 +765,12 @@ impl TransferEncoding {
*eof = true;
self.buffer.extend_from_slice(b"0\r\n\r\n");
}
},
}
}
}
}
impl io::Write for TransferEncoding {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.encode(Binary::from_slice(buf))?;
@ -773,7 +783,6 @@ impl io::Write for TransferEncoding {
}
}
struct AcceptEncoding {
encoding: ContentEncoding,
quality: f64,
@ -817,27 +826,31 @@ impl AcceptEncoding {
_ => match f64::from_str(parts[1]) {
Ok(q) => q,
Err(_) => 0.0,
}
},
};
Some(AcceptEncoding{ encoding, quality })
Some(AcceptEncoding {
encoding,
quality,
})
}
/// Parse a raw Accept-Encoding header value into an ordered list.
pub fn parse(raw: &str) -> ContentEncoding {
let mut encodings: Vec<_> =
raw.replace(' ', "").split(',').map(|l| AcceptEncoding::new(l)).collect();
let mut encodings: Vec<_> = raw.replace(' ', "")
.split(',')
.map(|l| AcceptEncoding::new(l))
.collect();
encodings.sort();
for enc in encodings {
if let Some(enc) = enc {
return enc.encoding
return enc.encoding;
}
}
ContentEncoding::Identity
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -846,9 +859,13 @@ mod tests {
fn test_chunked_te() {
let bytes = SharedBytes::default();
let mut enc = TransferEncoding::chunked(bytes.clone());
assert!(!enc.encode(Binary::from(b"test".as_ref())).ok().unwrap());
assert!(!enc.encode(Binary::from(b"test".as_ref()))
.ok()
.unwrap());
assert!(enc.encode(Binary::from(b"".as_ref())).ok().unwrap());
assert_eq!(bytes.get_mut().take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n"));
assert_eq!(
bytes.get_mut().take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +1,22 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, mem};
use std::rc::Rc;
use bytes::BufMut;
use futures::{Async, Poll};
use tokio_io::AsyncWrite;
use http::{Method, Version};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE};
use http::{Method, Version};
use std::rc::Rc;
use std::{io, mem};
use tokio_io::AsyncWrite;
use body::{Body, Binary};
use super::encoding::ContentEncoder;
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use body::{Binary, Body};
use header::ContentEncoding;
use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse;
use super::helpers;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use super::shared::SharedBytes;
use super::encoding::ContentEncoder;
use super::settings::WorkerSettings;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
@ -41,10 +41,9 @@ pub(crate) struct H1Writer<T: AsyncWrite, H: 'static> {
}
impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
pub fn new(stream: T, buf: SharedBytes, settings: Rc<WorkerSettings<H>>)
-> H1Writer<T, H>
{
pub fn new(
stream: T, buf: SharedBytes, settings: Rc<WorkerSettings<H>>
) -> H1Writer<T, H> {
H1Writer {
flags: Flags::empty(),
encoder: ContentEncoder::empty(buf.clone()),
@ -80,11 +79,11 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
match self.stream.write(&data[written..]) {
Ok(0) => {
self.disconnected();
return Err(io::Error::new(io::ErrorKind::WriteZero, ""))
},
return Err(io::Error::new(io::ErrorKind::WriteZero, ""));
}
Ok(n) => {
written += n;
},
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
return Ok(written)
}
@ -96,19 +95,18 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
}
impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline]
fn written(&self) -> u64 {
self.written
}
fn start(&mut self,
req: &mut HttpInnerMessage,
msg: &mut HttpResponse,
encoding: ContentEncoding) -> io::Result<WriterState>
{
fn start(
&mut self, req: &mut HttpInnerMessage, msg: &mut HttpResponse,
encoding: ContentEncoding,
) -> io::Result<WriterState> {
// prepare task
self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
self.encoder =
ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags.insert(Flags::STARTED | Flags::KEEPALIVE);
} else {
@ -119,15 +117,18 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
let version = msg.version().unwrap_or_else(|| req.version);
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade"));
msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
else if self.flags.contains(Flags::KEEPALIVE) {
if version < Version::HTTP_11 {
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("keep-alive"));
msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
} else if version >= Version::HTTP_11 {
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("close"));
msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("close"));
}
let body = msg.replace_body(Body::Empty);
@ -137,12 +138,14 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
let reason = msg.reason().as_bytes();
let mut is_bin = if let Body::Binary(ref bytes) = body {
buffer.reserve(
256 + msg.headers().len() * AVERAGE_HEADER_SIZE
+ bytes.len() + reason.len());
256 + msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()
+ reason.len(),
);
true
} else {
buffer.reserve(
256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len());
256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len(),
);
false
};
@ -151,51 +154,50 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
SharedBytes::extend_from_slice_(buffer, reason);
match body {
Body::Empty =>
if req.method != Method::HEAD {
SharedBytes::put_slice(buffer, b"\r\ncontent-length: 0\r\n");
} else {
SharedBytes::put_slice(buffer, b"\r\n");
},
Body::Binary(ref bytes) =>
helpers::write_content_length(bytes.len(), &mut buffer),
_ =>
SharedBytes::put_slice(buffer, b"\r\n"),
Body::Empty => if req.method != Method::HEAD {
SharedBytes::put_slice(buffer, b"\r\ncontent-length: 0\r\n");
} else {
SharedBytes::put_slice(buffer, b"\r\n");
},
Body::Binary(ref bytes) => {
helpers::write_content_length(bytes.len(), &mut buffer)
}
_ => SharedBytes::put_slice(buffer, b"\r\n"),
}
// write headers
let mut pos = 0;
let mut has_date = false;
let mut remaining = buffer.remaining_mut();
let mut buf: &mut [u8] = unsafe{ mem::transmute(buffer.bytes_mut()) };
let mut buf: &mut [u8] = unsafe { mem::transmute(buffer.bytes_mut()) };
for (key, value) in msg.headers() {
if is_bin && key == CONTENT_LENGTH {
is_bin = false;
continue
continue;
}
has_date = has_date || key == DATE;
let v = value.as_ref();
let k = key.as_str().as_bytes();
let len = k.len() + v.len() + 4;
if len > remaining {
unsafe{buffer.advance_mut(pos)};
unsafe { buffer.advance_mut(pos) };
pos = 0;
buffer.reserve(len);
remaining = buffer.remaining_mut();
buf = unsafe{ mem::transmute(buffer.bytes_mut()) };
buf = unsafe { mem::transmute(buffer.bytes_mut()) };
}
buf[pos..pos+k.len()].copy_from_slice(k);
buf[pos..pos + k.len()].copy_from_slice(k);
pos += k.len();
buf[pos..pos+2].copy_from_slice(b": ");
buf[pos..pos + 2].copy_from_slice(b": ");
pos += 2;
buf[pos..pos+v.len()].copy_from_slice(v);
buf[pos..pos + v.len()].copy_from_slice(v);
pos += v.len();
buf[pos..pos+2].copy_from_slice(b"\r\n");
buf[pos..pos + 2].copy_from_slice(b"\r\n");
pos += 2;
remaining -= len;
}
unsafe{buffer.advance_mut(pos)};
unsafe { buffer.advance_mut(pos) };
// optimized date header, set_date writes \r\n
if !has_date {
@ -256,8 +258,10 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
self.encoder.write_eof()?;
if !self.encoder.is_eof() {
Err(io::Error::new(io::ErrorKind::Other,
"Last payload item, but eof is not reached"))
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
} else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
Ok(WriterState::Pause)
} else {
@ -268,11 +272,11 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline]
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
if !self.buffer.is_empty() {
let buf: &[u8] = unsafe{mem::transmute(self.buffer.as_ref())};
let buf: &[u8] = unsafe { mem::transmute(self.buffer.as_ref()) };
let written = self.write_data(buf)?;
let _ = self.buffer.split_to(written);
if self.buffer.len() > self.buffer_capacity {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
if shutdown {

View File

@ -1,30 +1,30 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, cmp, mem};
use std::rc::Rc;
use std::io::{Read, Write};
use std::time::Duration;
use std::net::SocketAddr;
use std::collections::VecDeque;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::rc::Rc;
use std::time::Duration;
use std::{cmp, io, mem};
use actix::Arbiter;
use modhttp::request::Parts;
use http2::{Reason, RecvStream};
use http2::server::{self, Connection, Handshake, SendResponse};
use bytes::{Buf, Bytes};
use futures::{Async, Poll, Future, Stream};
use tokio_io::{AsyncRead, AsyncWrite};
use futures::{Async, Future, Poll, Stream};
use http2::server::{self, Connection, Handshake, SendResponse};
use http2::{Reason, RecvStream};
use modhttp::request::Parts;
use tokio_core::reactor::Timeout;
use tokio_io::{AsyncRead, AsyncWrite};
use pipeline::Pipeline;
use error::PayloadError;
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use payload::{Payload, PayloadWriter, PayloadStatus};
use payload::{Payload, PayloadStatus, PayloadWriter};
use pipeline::Pipeline;
use super::h2writer::H2Writer;
use super::encoding::PayloadType;
use super::h2writer::H2Writer;
use super::settings::WorkerSettings;
use super::{HttpHandler, HttpHandlerTask, Writer};
@ -35,9 +35,10 @@ bitflags! {
}
/// HTTP/2 Transport
pub(crate)
struct Http2<T, H>
where T: AsyncRead + AsyncWrite + 'static, H: 'static
pub(crate) struct Http2<T, H>
where
T: AsyncRead + AsyncWrite + 'static,
H: 'static,
{
flags: Flags,
settings: Rc<WorkerSettings<H>>,
@ -54,20 +55,23 @@ enum State<T: AsyncRead + AsyncWrite> {
}
impl<T, H> Http2<T, H>
where T: AsyncRead + AsyncWrite + 'static,
H: HttpHandler + 'static
where
T: AsyncRead + AsyncWrite + 'static,
H: HttpHandler + 'static,
{
pub fn new(settings: Rc<WorkerSettings<H>>,
io: T,
addr: Option<SocketAddr>, buf: Bytes) -> Self
{
Http2{ flags: Flags::empty(),
tasks: VecDeque::new(),
state: State::Handshake(
server::handshake(IoWrapper{unread: Some(buf), inner: io})),
keepalive_timer: None,
addr,
settings,
pub fn new(
settings: Rc<WorkerSettings<H>>, io: T, addr: Option<SocketAddr>, buf: Bytes
) -> Self {
Http2 {
flags: Flags::empty(),
tasks: VecDeque::new(),
state: State::Handshake(server::handshake(IoWrapper {
unread: Some(buf),
inner: io,
})),
keepalive_timer: None,
addr,
settings,
}
}
@ -89,7 +93,7 @@ impl<T, H> Http2<T, H>
match timeout.poll() {
Ok(Async::Ready(_)) => {
trace!("Keep-alive timeout, close connection");
return Ok(Async::Ready(()))
return Ok(Async::Ready(()));
}
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
@ -111,29 +115,30 @@ impl<T, H> Http2<T, H>
Ok(Async::Ready(ready)) => {
if ready {
item.flags.insert(
EntryFlags::EOF | EntryFlags::FINISHED);
EntryFlags::EOF | EntryFlags::FINISHED,
);
} else {
item.flags.insert(EntryFlags::EOF);
}
not_ready = false;
},
}
Ok(Async::NotReady) => {
if item.payload.need_read() == PayloadStatus::Read
&& !retry
{
continue
continue;
}
},
}
Err(err) => {
error!("Unhandled error: {}", err);
item.flags.insert(
EntryFlags::EOF |
EntryFlags::ERROR |
EntryFlags::WRITE_DONE);
EntryFlags::EOF | EntryFlags::ERROR
| EntryFlags::WRITE_DONE,
);
item.stream.reset(Reason::INTERNAL_ERROR);
}
}
break
break;
}
} else if !item.flags.contains(EntryFlags::FINISHED) {
match item.task.poll() {
@ -141,11 +146,12 @@ impl<T, H> Http2<T, H>
Ok(Async::Ready(_)) => {
not_ready = false;
item.flags.insert(EntryFlags::FINISHED);
},
}
Err(err) => {
item.flags.insert(
EntryFlags::ERROR | EntryFlags::WRITE_DONE |
EntryFlags::FINISHED);
EntryFlags::ERROR | EntryFlags::WRITE_DONE
| EntryFlags::FINISHED,
);
error!("Unhandled error: {}", err);
}
}
@ -167,13 +173,13 @@ impl<T, H> Http2<T, H>
// cleanup finished tasks
while !self.tasks.is_empty() {
if self.tasks[0].flags.contains(EntryFlags::EOF) &&
self.tasks[0].flags.contains(EntryFlags::WRITE_DONE) ||
self.tasks[0].flags.contains(EntryFlags::ERROR)
if self.tasks[0].flags.contains(EntryFlags::EOF)
&& self.tasks[0].flags.contains(EntryFlags::WRITE_DONE)
|| self.tasks[0].flags.contains(EntryFlags::ERROR)
{
self.tasks.pop_front();
} else {
break
break;
}
}
@ -186,7 +192,7 @@ impl<T, H> Http2<T, H>
for entry in &mut self.tasks {
entry.task.disconnected()
}
},
}
Ok(Async::Ready(Some((req, resp)))) => {
not_ready = false;
let (parts, body) = req.into_parts();
@ -194,8 +200,13 @@ impl<T, H> Http2<T, H>
// stop keepalive timer
self.keepalive_timer.take();
self.tasks.push_back(
Entry::new(parts, body, resp, self.addr, &self.settings));
self.tasks.push_back(Entry::new(
parts,
body,
resp,
self.addr,
&self.settings,
));
}
Ok(Async::NotReady) => {
// start keep-alive timer
@ -213,12 +224,13 @@ impl<T, H> Http2<T, H>
}
} else {
// keep-alive disable, drop connection
return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
return conn.poll_close().map_err(|e| {
error!("Error during connection close: {}", e)
});
}
} else {
// keep-alive unset, rely on operating system
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
Err(err) => {
@ -228,16 +240,17 @@ impl<T, H> Http2<T, H>
entry.task.disconnected()
}
self.keepalive_timer.take();
},
}
}
}
if not_ready {
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) {
return conn.poll_close().map_err(
|e| error!("Error during connection close: {}", e))
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED)
{
return conn.poll_close()
.map_err(|e| error!("Error during connection close: {}", e));
} else {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
}
@ -246,14 +259,11 @@ impl<T, H> Http2<T, H>
// handshake
self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() {
Ok(Async::Ready(conn)) => {
State::Connection(conn)
},
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Ok(Async::Ready(conn)) => State::Connection(conn),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
trace!("Error handling connection: {}", err);
return Err(())
return Err(());
}
}
} else {
@ -283,12 +293,12 @@ struct Entry<H: 'static> {
}
impl<H: 'static> Entry<H> {
fn new(parts: Parts,
recv: RecvStream,
resp: SendResponse<Bytes>,
addr: Option<SocketAddr>,
settings: &Rc<WorkerSettings<H>>) -> Entry<H>
where H: HttpHandler + 'static
fn new(
parts: Parts, recv: RecvStream, resp: SendResponse<Bytes>,
addr: Option<SocketAddr>, settings: &Rc<WorkerSettings<H>>,
) -> Entry<H>
where
H: HttpHandler + 'static,
{
// Payload and Content-Encoding
let (psender, payload) = Payload::new(false);
@ -312,18 +322,22 @@ impl<H: 'static> Entry<H> {
req = match h.handle(req) {
Ok(t) => {
task = Some(t);
break
},
break;
}
Err(req) => req,
}
}
Entry {task: task.unwrap_or_else(|| Pipeline::error(HttpResponse::NotFound())),
payload: psender,
stream: H2Writer::new(
resp, settings.get_shared_bytes(), Rc::clone(settings)),
flags: EntryFlags::empty(),
recv,
Entry {
task: task.unwrap_or_else(|| Pipeline::error(HttpResponse::NotFound())),
payload: psender,
stream: H2Writer::new(
resp,
settings.get_shared_bytes(),
Rc::clone(settings),
),
flags: EntryFlags::empty(),
recv,
}
}
@ -340,14 +354,12 @@ impl<H: 'static> Entry<H> {
match self.recv.poll() {
Ok(Async::Ready(Some(chunk))) => {
self.payload.feed_data(chunk);
},
}
Ok(Async::Ready(None)) => {
self.flags.insert(EntryFlags::REOF);
},
Ok(Async::NotReady) => (),
Err(err) => {
self.payload.set_error(PayloadError::Http2(err))
}
Ok(Async::NotReady) => (),
Err(err) => self.payload.set_error(PayloadError::Http2(err)),
}
}
}

View File

@ -1,25 +1,25 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, cmp};
use std::rc::Rc;
use bytes::{Bytes, BytesMut};
use futures::{Async, Poll};
use http2::{Reason, SendStream};
use http2::server::SendResponse;
use http2::{Reason, SendStream};
use modhttp::Response;
use std::rc::Rc;
use std::{cmp, io};
use http::{Version, HttpTryFrom};
use http::header::{HeaderValue, CONNECTION, TRANSFER_ENCODING, DATE, CONTENT_LENGTH};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{HttpTryFrom, Version};
use body::{Body, Binary};
use super::encoding::ContentEncoder;
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use body::{Binary, Body};
use header::ContentEncoding;
use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse;
use super::helpers;
use super::encoding::ContentEncoder;
use super::shared::SharedBytes;
use super::settings::WorkerSettings;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
const CHUNK_SIZE: usize = 16_384;
@ -44,10 +44,9 @@ pub(crate) struct H2Writer<H: 'static> {
}
impl<H: 'static> H2Writer<H> {
pub fn new(respond: SendResponse<Bytes>,
buf: SharedBytes, settings: Rc<WorkerSettings<H>>) -> H2Writer<H>
{
pub fn new(
respond: SendResponse<Bytes>, buf: SharedBytes, settings: Rc<WorkerSettings<H>>
) -> H2Writer<H> {
H2Writer {
respond,
settings,
@ -68,19 +67,18 @@ impl<H: 'static> H2Writer<H> {
}
impl<H: 'static> Writer for H2Writer<H> {
fn written(&self) -> u64 {
self.written
}
fn start(&mut self,
req: &mut HttpInnerMessage,
msg: &mut HttpResponse,
encoding: ContentEncoding) -> io::Result<WriterState>
{
fn start(
&mut self, req: &mut HttpInnerMessage, msg: &mut HttpResponse,
encoding: ContentEncoding,
) -> io::Result<WriterState> {
// prepare response
self.flags.insert(Flags::STARTED);
self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
self.encoder =
ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
if let Body::Empty = *msg.body() {
self.flags.insert(Flags::EOF);
}
@ -93,7 +91,8 @@ impl<H: 'static> Writer for H2Writer<H> {
if !msg.headers().contains_key(DATE) {
let mut bytes = BytesMut::with_capacity(29);
self.settings.set_date_simple(&mut bytes);
msg.headers_mut().insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap());
msg.headers_mut()
.insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap());
}
let body = msg.replace_body(Body::Empty);
@ -104,11 +103,13 @@ impl<H: 'static> Writer for H2Writer<H> {
let l = val.len();
msg.headers_mut().insert(
CONTENT_LENGTH,
HeaderValue::try_from(val.split_to(l-2).freeze()).unwrap());
HeaderValue::try_from(val.split_to(l - 2).freeze()).unwrap(),
);
}
Body::Empty => {
msg.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
},
msg.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
}
_ => (),
}
@ -119,11 +120,11 @@ impl<H: 'static> Writer for H2Writer<H> {
resp.headers_mut().insert(key, value.clone());
}
match self.respond.send_response(resp, self.flags.contains(Flags::EOF)) {
Ok(stream) =>
self.stream = Some(stream),
Err(_) =>
return Err(io::Error::new(io::ErrorKind::Other, "err")),
match self.respond
.send_response(resp, self.flags.contains(Flags::EOF))
{
Ok(stream) => self.stream = Some(stream),
Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "err")),
}
trace!("Response: {:?}", msg);
@ -169,8 +170,10 @@ impl<H: 'static> Writer for H2Writer<H> {
self.flags.insert(Flags::EOF);
if !self.encoder.is_eof() {
Err(io::Error::new(io::ErrorKind::Other,
"Last payload item, but eof is not reached"))
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
} else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
Ok(WriterState::Pause)
} else {
@ -197,17 +200,18 @@ impl<H: 'static> Writer for H2Writer<H> {
Ok(Async::Ready(Some(cap))) => {
let len = self.buffer.len();
let bytes = self.buffer.split_to(cmp::min(cap, len));
let eof = self.buffer.is_empty() && self.flags.contains(Flags::EOF);
let eof =
self.buffer.is_empty() && self.flags.contains(Flags::EOF);
self.written += bytes.len() as u64;
if let Err(e) = stream.send_data(bytes.freeze(), eof) {
return Err(io::Error::new(io::ErrorKind::Other, e))
return Err(io::Error::new(io::ErrorKind::Other, e));
} else if !self.buffer.is_empty() {
let cap = cmp::min(self.buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap);
} else {
self.flags.remove(Flags::RESERVED);
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),

View File

@ -1,9 +1,9 @@
use std::{mem, ptr, slice};
use std::cell::RefCell;
use std::rc::Rc;
use std::collections::VecDeque;
use bytes::{BufMut, BytesMut};
use http::Version;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::{mem, ptr, slice};
use httprequest::HttpInnerMessage;
@ -35,7 +35,9 @@ impl SharedMessagePool {
}
pub(crate) struct SharedHttpInnerMessage(
Option<Rc<HttpInnerMessage>>, Option<Rc<SharedMessagePool>>);
Option<Rc<HttpInnerMessage>>,
Option<Rc<SharedMessagePool>>,
);
impl Drop for SharedHttpInnerMessage {
fn drop(&mut self) {
@ -50,26 +52,25 @@ impl Drop for SharedHttpInnerMessage {
}
impl Clone for SharedHttpInnerMessage {
fn clone(&self) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(self.0.clone(), self.1.clone())
}
}
impl Default for SharedHttpInnerMessage {
fn default() -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(Rc::new(HttpInnerMessage::default())), None)
}
}
impl SharedHttpInnerMessage {
pub fn from_message(msg: HttpInnerMessage) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(Rc::new(msg)), None)
}
pub fn new(msg: Rc<HttpInnerMessage>, pool: Rc<SharedMessagePool>) -> SharedHttpInnerMessage {
pub fn new(
msg: Rc<HttpInnerMessage>, pool: Rc<SharedMessagePool>
) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(msg), Some(pool))
}
@ -78,7 +79,7 @@ impl SharedHttpInnerMessage {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))]
pub fn get_mut(&self) -> &mut HttpInnerMessage {
let r: &HttpInnerMessage = self.0.as_ref().unwrap().as_ref();
unsafe{mem::transmute(r)}
unsafe { mem::transmute(r) }
}
#[inline(always)]
@ -88,20 +89,23 @@ impl SharedHttpInnerMessage {
}
}
const DEC_DIGITS_LUT: &[u8] =
b"0001020304050607080910111213141516171819\
const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\
2021222324252627282930313233343536373839\
4041424344454647484950515253545556575859\
6061626364656667686970717273747576777879\
8081828384858687888990919293949596979899";
pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) {
let mut buf: [u8; 13] = [b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1',
b' ', b' ', b' ', b' ', b' '];
let mut buf: [u8; 13] = [
b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', b' ', b' ', b' ', b' ', b' '
];
match version {
Version::HTTP_2 => buf[5] = b'2',
Version::HTTP_10 => buf[7] = b'0',
Version::HTTP_09 => {buf[5] = b'0'; buf[7] = b'9';},
Version::HTTP_09 => {
buf[5] = b'0';
buf[7] = b'9';
}
_ => (),
}
@ -124,7 +128,11 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(
lut_ptr.offset(d1 as isize),
buf_ptr.offset(curr),
2,
);
}
}
@ -137,30 +145,41 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM
/// NOTE: bytes object has to contain enough space
pub(crate) fn write_content_length(mut n: usize, bytes: &mut BytesMut) {
if n < 10 {
let mut buf: [u8; 21] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e',
b'n',b't',b'-',b'l',b'e',b'n',b'g',
b't',b'h',b':',b' ',b'0',b'\r',b'\n'];
let mut buf: [u8; 21] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b'n', b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n',
];
buf[18] = (n as u8) + b'0';
bytes.put_slice(&buf);
} else if n < 100 {
let mut buf: [u8; 22] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e',
b'n',b't',b'-',b'l',b'e',b'n',b'g',
b't',b'h',b':',b' ',b'0',b'0',b'\r',b'\n'];
let mut buf: [u8; 22] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n',
];
let d1 = n << 1;
unsafe {
ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(18), 2);
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize),
buf.as_mut_ptr().offset(18),
2,
);
}
bytes.put_slice(&buf);
} else if n < 1000 {
let mut buf: [u8; 23] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e',
b'n',b't',b'-',b'l',b'e',b'n',b'g',
b't',b'h',b':',b' ',b'0',b'0',b'0',b'\r',b'\n'];
let mut buf: [u8; 23] = [
b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', b'\n',
];
// decode 2 more chars, if > 2 chars
let d1 = (n % 100) << 1;
n /= 100;
unsafe {ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(19), 2)};
unsafe {
ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize),
buf.as_mut_ptr().offset(19),
2,
)
};
// decode last 1
buf[18] = (n as u8) + b'0';
@ -216,12 +235,13 @@ pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) {
}
unsafe {
bytes.extend_from_slice(
slice::from_raw_parts(buf_ptr.offset(curr), 41 - curr as usize));
bytes.extend_from_slice(slice::from_raw_parts(
buf_ptr.offset(curr),
41 - curr as usize,
));
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -231,33 +251,63 @@ mod tests {
let mut bytes = BytesMut::new();
bytes.reserve(50);
write_content_length(0, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 0\r\n"[..]
);
bytes.reserve(50);
write_content_length(9, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 9\r\n"[..]
);
bytes.reserve(50);
write_content_length(10, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 10\r\n"[..]
);
bytes.reserve(50);
write_content_length(99, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 99\r\n"[..]
);
bytes.reserve(50);
write_content_length(100, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 100\r\n"[..]
);
bytes.reserve(50);
write_content_length(101, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 101\r\n"[..]
);
bytes.reserve(50);
write_content_length(998, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 998\r\n"[..]
);
bytes.reserve(50);
write_content_length(1000, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 1000\r\n"[..]
);
bytes.reserve(50);
write_content_length(1001, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 1001\r\n"[..]
);
bytes.reserve(50);
write_content_length(5909, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]);
assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 5909\r\n"[..]
);
}
}

View File

@ -1,27 +1,27 @@
//! Http server
use std::{time, io};
use std::net::Shutdown;
use std::{io, time};
use actix;
use futures::Poll;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_core::net::TcpStream;
use tokio_io::{AsyncRead, AsyncWrite};
mod srv;
mod worker;
mod channel;
pub(crate) mod encoding;
pub(crate) mod h1;
mod h2;
mod h1writer;
mod h2;
mod h2writer;
mod settings;
pub(crate) mod helpers;
mod settings;
pub(crate) mod shared;
mod srv;
pub(crate) mod utils;
mod worker;
pub use self::srv::HttpServer;
pub use self::settings::ServerSettings;
pub use self::srv::HttpServer;
use body::Binary;
use error::Error;
@ -56,9 +56,10 @@ pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// }
/// ```
pub fn new<F, U, H>(factory: F) -> HttpServer<H>
where F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item=H> + 'static,
H: IntoHttpHandler + 'static
where
F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item = H> + 'static,
H: IntoHttpHandler + 'static,
{
HttpServer::new(factory)
}
@ -107,7 +108,7 @@ pub struct ResumeServer;
///
/// If server starts with `spawn()` method, then spawned thread get terminated.
pub struct StopServer {
pub graceful: bool
pub graceful: bool,
}
impl actix::Message for StopServer {
@ -117,7 +118,6 @@ impl actix::Message for StopServer {
/// Low level http request handler
#[allow(unused_variables)]
pub trait HttpHandler: 'static {
/// Handle request
fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest>;
}
@ -130,7 +130,6 @@ impl HttpHandler for Box<HttpHandler> {
#[doc(hidden)]
pub trait HttpHandlerTask {
/// Poll task, this method is used before or after *io* object is available
fn poll(&mut self) -> Poll<(), Error>;
@ -170,8 +169,10 @@ pub enum WriterState {
pub trait Writer {
fn written(&self) -> u64;
fn start(&mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse, encoding: ContentEncoding)
-> io::Result<WriterState>;
fn start(
&mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse,
encoding: ContentEncoding,
) -> io::Result<WriterState>;
fn write(&mut self, payload: Binary) -> io::Result<WriterState>;
@ -207,10 +208,10 @@ impl IoStream for TcpStream {
}
}
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use tokio_openssl::SslStream;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
impl IoStream for SslStream<TcpStream> {
#[inline]
fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {
@ -229,10 +230,10 @@ impl IoStream for SslStream<TcpStream> {
}
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use tokio_tls::TlsStream;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
impl IoStream for TlsStream<TcpStream> {
#[inline]
fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {

View File

@ -1,19 +1,19 @@
use std::{fmt, mem, net};
use bytes::BytesMut;
use futures_cpupool::{Builder, CpuPool};
use http::StatusCode;
use std::cell::{Cell, RefCell, RefMut, UnsafeCell};
use std::fmt::Write;
use std::rc::Rc;
use std::sync::Arc;
use std::cell::{Cell, RefCell, RefMut, UnsafeCell};
use std::{fmt, mem, net};
use time;
use bytes::BytesMut;
use http::StatusCode;
use futures_cpupool::{Builder, CpuPool};
use super::helpers;
use super::KeepAlive;
use super::channel::Node;
use super::helpers;
use super::shared::{SharedBytes, SharedBytesPool};
use body::Body;
use httpresponse::{HttpResponse, HttpResponsePool, HttpResponseBuilder};
use httpresponse::{HttpResponse, HttpResponseBuilder, HttpResponsePool};
/// Various server settings
#[derive(Clone)]
@ -71,9 +71,9 @@ impl Default for ServerSettings {
impl ServerSettings {
/// Crate server settings instance
pub(crate) fn new(addr: Option<net::SocketAddr>, host: &Option<String>, secure: bool)
-> ServerSettings
{
pub(crate) fn new(
addr: Option<net::SocketAddr>, host: &Option<String>, secure: bool
) -> ServerSettings {
let host = if let Some(ref host) = *host {
host.clone()
} else if let Some(ref addr) = addr {
@ -83,7 +83,13 @@ impl ServerSettings {
};
let cpu_pool = Arc::new(InnerCpuPool::new());
let responses = HttpResponsePool::pool();
ServerSettings { addr, secure, host, cpu_pool, responses }
ServerSettings {
addr,
secure,
host,
cpu_pool,
responses,
}
}
/// Returns the socket address of the local half of this TCP connection
@ -112,12 +118,13 @@ impl ServerSettings {
}
#[inline]
pub(crate) fn get_response_builder(&self, status: StatusCode) -> HttpResponseBuilder {
pub(crate) fn get_response_builder(
&self, status: StatusCode
) -> HttpResponseBuilder {
HttpResponsePool::get_builder(&self.responses, status)
}
}
// "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29;
@ -141,7 +148,8 @@ impl<H> WorkerSettings<H> {
};
WorkerSettings {
keep_alive, ka_enabled,
keep_alive,
ka_enabled,
h: RefCell::new(h),
bytes: Rc::new(SharedBytesPool::new()),
messages: Rc::new(helpers::SharedMessagePool::new()),
@ -176,7 +184,10 @@ impl<H> WorkerSettings<H> {
}
pub fn get_http_message(&self) -> helpers::SharedHttpInnerMessage {
helpers::SharedHttpInnerMessage::new(self.messages.get(), Rc::clone(&self.messages))
helpers::SharedHttpInnerMessage::new(
self.messages.get(),
Rc::clone(&self.messages),
)
}
pub fn add_channel(&self) {
@ -186,26 +197,26 @@ impl<H> WorkerSettings<H> {
pub fn remove_channel(&self) {
let num = self.channels.get();
if num > 0 {
self.channels.set(num-1);
self.channels.set(num - 1);
} else {
error!("Number of removed channels is bigger than added channel. Bug in actix-web");
}
}
pub fn update_date(&self) {
unsafe{&mut *self.date.get()}.update();
unsafe { &mut *self.date.get() }.update();
}
pub fn set_date(&self, dst: &mut BytesMut) {
let mut buf: [u8; 39] = unsafe { mem::uninitialized() };
buf[..6].copy_from_slice(b"date: ");
buf[6..35].copy_from_slice(&(unsafe{&*self.date.get()}.bytes));
buf[6..35].copy_from_slice(&(unsafe { &*self.date.get() }.bytes));
buf[35..].copy_from_slice(b"\r\n\r\n");
dst.extend_from_slice(&buf);
}
pub fn set_date_simple(&self, dst: &mut BytesMut) {
dst.extend_from_slice(&(unsafe{&*self.date.get()}.bytes));
dst.extend_from_slice(&(unsafe { &*self.date.get() }.bytes));
}
}
@ -216,7 +227,10 @@ struct Date {
impl Date {
fn new() -> Date {
let mut date = Date{bytes: [0; DATE_VALUE_LENGTH], pos: 0};
let mut date = Date {
bytes: [0; DATE_VALUE_LENGTH],
pos: 0,
};
date.update();
date
}
@ -235,14 +249,16 @@ impl fmt::Write for Date {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_date_len() {
assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len());
assert_eq!(
DATE_VALUE_LENGTH,
"Sun, 06 Nov 1994 08:49:37 GMT".len()
);
}
#[test]

View File

@ -1,12 +1,11 @@
use std::{io, mem};
use std::cell::RefCell;
use std::rc::Rc;
use std::collections::VecDeque;
use bytes::{BufMut, BytesMut};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::{io, mem};
use body::Binary;
/// Internal use only! unsafe
#[derive(Debug)]
pub(crate) struct SharedBytesPool(RefCell<VecDeque<Rc<BytesMut>>>);
@ -34,8 +33,7 @@ impl SharedBytesPool {
}
#[derive(Debug)]
pub(crate) struct SharedBytes(
Option<Rc<BytesMut>>, Option<Rc<SharedBytesPool>>);
pub(crate) struct SharedBytes(Option<Rc<BytesMut>>, Option<Rc<SharedBytesPool>>);
impl Drop for SharedBytes {
fn drop(&mut self) {
@ -50,7 +48,6 @@ impl Drop for SharedBytes {
}
impl SharedBytes {
pub fn empty() -> Self {
SharedBytes(None, None)
}
@ -64,7 +61,7 @@ impl SharedBytes {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))]
pub(crate) fn get_mut(&self) -> &mut BytesMut {
let r: &BytesMut = self.0.as_ref().unwrap().as_ref();
unsafe{mem::transmute(r)}
unsafe { mem::transmute(r) }
}
#[inline]

View File

@ -1,31 +1,33 @@
use std::{io, net, thread};
use std::rc::Rc;
use std::sync::{Arc, mpsc as sync_mpsc};
use std::sync::{mpsc as sync_mpsc, Arc};
use std::time::Duration;
use std::{io, net, thread};
use actix::prelude::*;
use actix::actors::signal;
use futures::{Future, Sink, Stream};
use actix::prelude::*;
use futures::sync::mpsc;
use tokio_io::{AsyncRead, AsyncWrite};
use futures::{Future, Sink, Stream};
use mio;
use num_cpus;
use net2::TcpBuilder;
use num_cpus;
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use native_tls::TlsAcceptor;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use openssl::ssl::{AlpnError, SslAcceptorBuilder};
use super::channel::{HttpChannel, WrapperStream};
use super::settings::{ServerSettings, WorkerSettings};
use super::worker::{Conn, StopWorker, StreamHandlerType, Worker};
use super::{IntoHttpHandler, IoStream, KeepAlive};
use super::{PauseServer, ResumeServer, StopServer};
use super::channel::{HttpChannel, WrapperStream};
use super::worker::{Conn, Worker, StreamHandlerType, StopWorker};
use super::settings::{ServerSettings, WorkerSettings};
/// An HTTP Server
pub struct HttpServer<H> where H: IntoHttpHandler + 'static
pub struct HttpServer<H>
where
H: IntoHttpHandler + 'static,
{
h: Option<Rc<WorkerSettings<H::Handler>>>,
threads: usize,
@ -33,7 +35,7 @@ pub struct HttpServer<H> where H: IntoHttpHandler + 'static
host: Option<String>,
keep_alive: KeepAlive,
factory: Arc<Fn() -> Vec<H> + Send + Sync>,
#[cfg_attr(feature="cargo-clippy", allow(type_complexity))]
#[cfg_attr(feature = "cargo-clippy", allow(type_complexity))]
workers: Vec<(usize, Addr<Syn, Worker<H::Handler>>)>,
sockets: Vec<(net::SocketAddr, net::TcpListener)>,
accept: Vec<(mio::SetReadiness, sync_mpsc::Sender<Command>)>,
@ -44,8 +46,16 @@ pub struct HttpServer<H> where H: IntoHttpHandler + 'static
no_signals: bool,
}
unsafe impl<H> Sync for HttpServer<H> where H: IntoHttpHandler {}
unsafe impl<H> Send for HttpServer<H> where H: IntoHttpHandler {}
unsafe impl<H> Sync for HttpServer<H>
where
H: IntoHttpHandler,
{
}
unsafe impl<H> Send for HttpServer<H>
where
H: IntoHttpHandler,
{
}
#[derive(Clone)]
struct Info {
@ -57,41 +67,47 @@ enum ServerCommand {
WorkerDied(usize, Info),
}
impl<H> Actor for HttpServer<H> where H: IntoHttpHandler {
impl<H> Actor for HttpServer<H>
where
H: IntoHttpHandler,
{
type Context = Context<Self>;
}
impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
impl<H> HttpServer<H>
where
H: IntoHttpHandler + 'static,
{
/// Create new http server with application factory
pub fn new<F, U>(factory: F) -> Self
where F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item=H> + 'static,
where
F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item = H> + 'static,
{
let f = move || {
(factory)().into_iter().collect()
};
let f = move || (factory)().into_iter().collect();
HttpServer{ h: None,
threads: num_cpus::get(),
backlog: 2048,
host: None,
keep_alive: KeepAlive::Os,
factory: Arc::new(f),
workers: Vec::new(),
sockets: Vec::new(),
accept: Vec::new(),
exit: false,
shutdown_timeout: 30,
signals: None,
no_http2: false,
no_signals: false,
HttpServer {
h: None,
threads: num_cpus::get(),
backlog: 2048,
host: None,
keep_alive: KeepAlive::Os,
factory: Arc::new(f),
workers: Vec::new(),
sockets: Vec::new(),
accept: Vec::new(),
exit: false,
shutdown_timeout: 30,
signals: None,
no_http2: false,
no_signals: false,
}
}
/// Set number of workers to start.
///
/// By default http server uses number of available logical cpu as threads count.
/// By default http server uses number of available logical cpu as threads
/// count.
pub fn threads(mut self, num: usize) -> Self {
self.threads = num;
self
@ -101,7 +117,8 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
///
/// This refers to the number of clients that can be waiting to be served.
/// Exceeding this number results in the client getting an error when
/// attempting to connect. It should only affect servers under significant load.
/// attempting to connect. It should only affect servers under significant
/// load.
///
/// Generally set in the 64-2048 range. Default value is 2048.
///
@ -121,9 +138,9 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
/// Set server host name.
///
/// Host name is used by application router aa a hostname for url generation.
/// Check [ConnectionInfo](./dev/struct.ConnectionInfo.html#method.host) documentation
/// for more information.
/// Host name is used by application router aa a hostname for url
/// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo.
/// html#method.host) documentation for more information.
pub fn server_hostname(mut self, val: String) -> Self {
self.host = Some(val);
self
@ -152,8 +169,9 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
/// Timeout for graceful workers shutdown.
///
/// After receiving a stop signal, workers have this much time to finish serving requests.
/// Workers still alive after the timeout are force dropped.
/// After receiving a stop signal, workers have this much time to finish
/// serving requests. Workers still alive after the timeout are force
/// dropped.
///
/// By default shutdown timeout sets to 30 seconds.
pub fn shutdown_timeout(mut self, sec: u16) -> Self {
@ -192,7 +210,7 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
Ok(lst) => {
succ = true;
self.sockets.push((lst.local_addr().unwrap(), lst));
},
}
Err(e) => err = Some(e),
}
}
@ -201,16 +219,19 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
if let Some(e) = err.take() {
Err(e)
} else {
Err(io::Error::new(io::ErrorKind::Other, "Can not bind to address."))
Err(io::Error::new(
io::ErrorKind::Other,
"Can not bind to address.",
))
}
} else {
Ok(self)
}
}
fn start_workers(&mut self, settings: &ServerSettings, handler: &StreamHandlerType)
-> Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)>
{
fn start_workers(
&mut self, settings: &ServerSettings, handler: &StreamHandlerType
) -> Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)> {
// start workers
let mut workers = Vec::new();
for idx in 0..self.threads {
@ -223,7 +244,8 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
let addr = Arbiter::start(move |ctx: &mut Context<_>| {
let apps: Vec<_> = (*factory)()
.into_iter()
.map(|h| h.into_handler(s.clone())).collect();
.map(|h| h.into_handler(s.clone()))
.collect();
ctx.add_message_stream(rx);
Worker::new(apps, h, ka)
});
@ -248,12 +270,12 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
}
}
impl<H: IntoHttpHandler> HttpServer<H>
{
impl<H: IntoHttpHandler> HttpServer<H> {
/// Start listening for incoming connections.
///
/// This method starts number of http handler workers in separate threads.
/// For each address this method starts separate thread which does `accept()` in a loop.
/// For each address this method starts separate thread which does
/// `accept()` in a loop.
///
/// This methods panics if no socket addresses get bound.
///
@ -277,8 +299,7 @@ impl<H: IntoHttpHandler> HttpServer<H>
/// let _ = sys.run(); // <- Run actix system, this method actually starts all async processes
/// }
/// ```
pub fn start(mut self) -> Addr<Syn, Self>
{
pub fn start(mut self) -> Addr<Syn, Self> {
if self.sockets.is_empty() {
panic!("HttpServer::bind() has to be called before start()");
} else {
@ -287,15 +308,22 @@ impl<H: IntoHttpHandler> HttpServer<H>
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(&settings, &StreamHandlerType::Normal);
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal};
let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Normal,
};
// start acceptors threads
for (addr, sock) in addrs {
info!("Starting server on http://{}", addr);
self.accept.push(
start_accept_thread(
sock, addr, self.backlog,
tx.clone(), info.clone(), workers.clone()));
self.accept.push(start_accept_thread(
sock,
addr,
self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
}
// start http server actor
@ -304,16 +332,17 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx);
self
});
signals.map(|signals| signals.do_send(
signal::Subscribe(addr.clone().recipient())));
signals.map(|signals| {
signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
addr
}
}
/// Spawn new thread and start listening for incoming connections.
///
/// This method spawns new thread and starts new actix system. Other than that it is
/// similar to `start()` method. This method blocks.
/// This method spawns new thread and starts new actix system. Other than
/// that it is similar to `start()` method. This method blocks.
///
/// This methods panics if no socket addresses get bound.
///
@ -344,28 +373,38 @@ impl<H: IntoHttpHandler> HttpServer<H>
}
}
#[cfg(feature="tls")]
impl<H: IntoHttpHandler> HttpServer<H>
{
#[cfg(feature = "tls")]
impl<H: IntoHttpHandler> HttpServer<H> {
/// Start listening for incoming tls connections.
pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result<Addr<Syn, Self>> {
if self.sockets.is_empty() {
Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound"))
Err(io::Error::new(
io::ErrorKind::Other,
"No socket addresses are bound",
))
} else {
let (tx, rx) = mpsc::unbounded();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> =
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(
&settings, &StreamHandlerType::Tls(acceptor.clone()));
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Tls(acceptor)};
let workers =
self.start_workers(&settings, &StreamHandlerType::Tls(acceptor.clone()));
let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Tls(acceptor),
};
// start acceptors threads
for (addr, sock) in addrs {
info!("Starting server on https://{}", addr);
self.accept.push(
start_accept_thread(
sock, addr, self.backlog,
tx.clone(), info.clone(), workers.clone()));
self.accept.push(start_accept_thread(
sock,
addr,
self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
}
// start http server actor
@ -374,23 +413,27 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx);
self
});
signals.map(|signals| signals.do_send(
signal::Subscribe(addr.clone().recipient())));
signals.map(|signals| {
signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
Ok(addr)
}
}
}
#[cfg(feature="alpn")]
impl<H: IntoHttpHandler> HttpServer<H>
{
#[cfg(feature = "alpn")]
impl<H: IntoHttpHandler> HttpServer<H> {
/// Start listening for incoming tls connections.
///
/// This method sets alpn protocols to "h2" and "http/1.1"
pub fn start_ssl(mut self, mut builder: SslAcceptorBuilder) -> io::Result<Addr<Syn, Self>>
{
pub fn start_ssl(
mut self, mut builder: SslAcceptorBuilder
) -> io::Result<Addr<Syn, Self>> {
if self.sockets.is_empty() {
Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound"))
Err(io::Error::new(
io::ErrorKind::Other,
"No socket addresses are bound",
))
} else {
// alpn support
if !self.no_http2 {
@ -407,19 +450,29 @@ impl<H: IntoHttpHandler> HttpServer<H>
let (tx, rx) = mpsc::unbounded();
let acceptor = builder.build();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> =
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(
&settings, &StreamHandlerType::Alpn(acceptor.clone()));
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Alpn(acceptor)};
&settings,
&StreamHandlerType::Alpn(acceptor.clone()),
);
let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Alpn(acceptor),
};
// start acceptors threads
for (addr, sock) in addrs {
info!("Starting server on https://{}", addr);
self.accept.push(
start_accept_thread(
sock, addr, self.backlog,
tx.clone(), info.clone(), workers.clone()));
self.accept.push(start_accept_thread(
sock,
addr,
self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
}
// start http server actor
@ -428,22 +481,23 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx);
self
});
signals.map(|signals| signals.do_send(
signal::Subscribe(addr.clone().recipient())));
signals.map(|signals| {
signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
Ok(addr)
}
}
}
impl<H: IntoHttpHandler> HttpServer<H>
{
impl<H: IntoHttpHandler> HttpServer<H> {
/// Start listening for incoming connections from a stream.
///
/// This method uses only one thread for handling incoming connections.
pub fn start_incoming<T, A, S>(mut self, stream: S, secure: bool) -> Addr<Syn, Self>
where S: Stream<Item=(T, A), Error=io::Error> + 'static,
T: AsyncRead + AsyncWrite + 'static,
A: 'static
where
S: Stream<Item = (T, A), Error = io::Error> + 'static,
T: AsyncRead + AsyncWrite + 'static,
A: 'static,
{
let (tx, rx) = mpsc::unbounded();
@ -452,15 +506,22 @@ impl<H: IntoHttpHandler> HttpServer<H>
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(&settings, &StreamHandlerType::Normal);
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal};
let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Normal,
};
// start acceptors threads
for (addr, sock) in addrs {
info!("Starting server on http://{}", addr);
self.accept.push(
start_accept_thread(
sock, addr, self.backlog,
tx.clone(), info.clone(), workers.clone()));
self.accept.push(start_accept_thread(
sock,
addr,
self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
}
}
@ -468,21 +529,24 @@ impl<H: IntoHttpHandler> HttpServer<H>
let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
let settings = ServerSettings::new(Some(addr), &self.host, secure);
let apps: Vec<_> = (*self.factory)()
.into_iter().map(|h| h.into_handler(settings.clone())).collect();
.into_iter()
.map(|h| h.into_handler(settings.clone()))
.collect();
self.h = Some(Rc::new(WorkerSettings::new(apps, self.keep_alive)));
// start server
let signals = self.subscribe_to_signals();
let addr: Addr<Syn, _> = HttpServer::create(move |ctx| {
ctx.add_stream(rx);
ctx.add_message_stream(
stream
.map_err(|_| ())
.map(move |(t, _)| Conn{io: WrapperStream::new(t), peer: None, http2: false}));
ctx.add_message_stream(stream.map_err(|_| ()).map(move |(t, _)| Conn {
io: WrapperStream::new(t),
peer: None,
http2: false,
}));
self
});
signals.map(|signals| signals.do_send(
signal::Subscribe(addr.clone().recipient())));
signals
.map(|signals| signals.do_send(signal::Subscribe(addr.clone().recipient())));
addr
}
}
@ -490,8 +554,7 @@ impl<H: IntoHttpHandler> HttpServer<H>
/// Signals support
/// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)`
/// message to `System` actor.
impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H>
{
impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H> {
type Result = ();
fn handle(&mut self, msg: signal::Signal, ctx: &mut Context<Self>) {
@ -499,17 +562,17 @@ impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H>
signal::SignalType::Int => {
info!("SIGINT received, exiting");
self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: false}, ctx);
Handler::<StopServer>::handle(self, StopServer { graceful: false }, ctx);
}
signal::SignalType::Term => {
info!("SIGTERM received, stopping");
self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: true}, ctx);
Handler::<StopServer>::handle(self, StopServer { graceful: true }, ctx);
}
signal::SignalType::Quit => {
info!("SIGQUIT received, exiting");
self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: false}, ctx);
Handler::<StopServer>::handle(self, StopServer { graceful: false }, ctx);
}
_ => (),
}
@ -517,8 +580,7 @@ impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H>
}
/// Commands from accept threads
impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
{
impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H> {
fn finished(&mut self, _: &mut Context<Self>) {}
fn handle(&mut self, msg: ServerCommand, _: &mut Context<Self>) {
match msg {
@ -528,7 +590,7 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
if self.workers[i].0 == idx {
self.workers.swap_remove(i);
found = true;
break
break;
}
}
@ -541,21 +603,23 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
for i in 0..self.workers.len() {
if self.workers[i].0 == new_idx {
new_idx += 1;
continue 'found
continue 'found;
}
}
break
break;
}
let h = info.handler;
let ka = self.keep_alive;
let factory = Arc::clone(&self.factory);
let settings = ServerSettings::new(Some(info.addr), &self.host, false);
let settings =
ServerSettings::new(Some(info.addr), &self.host, false);
let addr = Arbiter::start(move |ctx: &mut Context<_>| {
let apps: Vec<_> = (*factory)()
.into_iter()
.map(|h| h.into_handler(settings.clone())).collect();
.map(|h| h.into_handler(settings.clone()))
.collect();
ctx.add_message_stream(rx);
Worker::new(apps, h, ka)
});
@ -566,30 +630,32 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
self.workers.push((new_idx, addr));
}
},
}
}
}
}
impl<T, H> Handler<Conn<T>> for HttpServer<H>
where T: IoStream,
H: IntoHttpHandler,
where
T: IoStream,
H: IntoHttpHandler,
{
type Result = ();
fn handle(&mut self, msg: Conn<T>, _: &mut Context<Self>) -> Self::Result {
Arbiter::handle().spawn(
HttpChannel::new(
Rc::clone(self.h.as_ref().unwrap()), msg.io, msg.peer, msg.http2));
Arbiter::handle().spawn(HttpChannel::new(
Rc::clone(self.h.as_ref().unwrap()),
msg.io,
msg.peer,
msg.http2,
));
}
}
impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H>
{
impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H> {
type Result = ();
fn handle(&mut self, _: PauseServer, _: &mut Context<Self>)
{
fn handle(&mut self, _: PauseServer, _: &mut Context<Self>) {
for item in &self.accept {
let _ = item.1.send(Command::Pause);
let _ = item.0.set_readiness(mio::Ready::readable());
@ -597,8 +663,7 @@ impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H>
}
}
impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H>
{
impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H> {
type Result = ();
fn handle(&mut self, _: ResumeServer, _: &mut Context<Self>) {
@ -609,8 +674,7 @@ impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H>
}
}
impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H>
{
impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H> {
type Result = actix::Response<(), ()>;
fn handle(&mut self, msg: StopServer, ctx: &mut Context<Self>) -> Self::Result {
@ -630,7 +694,9 @@ impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H>
};
for worker in &self.workers {
let tx2 = tx.clone();
worker.1.send(StopWorker{graceful: dur})
worker
.1
.send(StopWorker { graceful: dur })
.into_actor(self)
.then(move |_, slf, ctx| {
slf.workers.pop();
@ -645,12 +711,12 @@ impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H>
}
}
actix::fut::ok(())
}).spawn(ctx);
})
.spawn(ctx);
}
if !self.workers.is_empty() {
Response::async(
rx.into_future().map(|_| ()).map_err(|_| ()))
Response::async(rx.into_future().map(|_| ()).map_err(|_| ()))
} else {
// we need to stop system if server was spawned
if self.exit {
@ -673,156 +739,184 @@ enum Command {
fn start_accept_thread(
sock: net::TcpListener, addr: net::SocketAddr, backlog: i32,
srv: mpsc::UnboundedSender<ServerCommand>, info: Info,
mut workers: Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)>)
-> (mio::SetReadiness, sync_mpsc::Sender<Command>)
{
mut workers: Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)>,
) -> (mio::SetReadiness, sync_mpsc::Sender<Command>) {
let (tx, rx) = sync_mpsc::channel();
let (reg, readiness) = mio::Registration::new2();
// start accept thread
#[cfg_attr(feature="cargo-clippy", allow(cyclomatic_complexity))]
let _ = thread::Builder::new().name(format!("Accept on {}", addr)).spawn(move || {
const SRV: mio::Token = mio::Token(0);
const CMD: mio::Token = mio::Token(1);
#[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))]
let _ = thread::Builder::new()
.name(format!("Accept on {}", addr))
.spawn(move || {
const SRV: mio::Token = mio::Token(0);
const CMD: mio::Token = mio::Token(1);
let mut server = Some(
mio::net::TcpListener::from_std(sock)
.expect("Can not create mio::net::TcpListener"));
let mut server = Some(
mio::net::TcpListener::from_std(sock)
.expect("Can not create mio::net::TcpListener"),
);
// Create a poll instance
let poll = match mio::Poll::new() {
Ok(poll) => poll,
Err(err) => panic!("Can not create mio::Poll: {}", err),
};
// Create a poll instance
let poll = match mio::Poll::new() {
Ok(poll) => poll,
Err(err) => panic!("Can not create mio::Poll: {}", err),
};
// Start listening for incoming connections
if let Some(ref srv) = server {
if let Err(err) = poll.register(
srv, SRV, mio::Ready::readable(), mio::PollOpt::edge()) {
panic!("Can not register io: {}", err);
}
}
// Start listening for incoming commands
if let Err(err) = poll.register(&reg, CMD,
mio::Ready::readable(), mio::PollOpt::edge()) {
panic!("Can not register Registration: {}", err);
}
// Create storage for events
let mut events = mio::Events::with_capacity(128);
// Sleep on error
let sleep = Duration::from_millis(100);
let mut next = 0;
loop {
if let Err(err) = poll.poll(&mut events, None) {
panic!("Poll error: {}", err);
}
for event in events.iter() {
match event.token() {
SRV => if let Some(ref server) = server {
loop {
match server.accept_std() {
Ok((sock, addr)) => {
let mut msg = Conn{
io: sock, peer: Some(addr), http2: false};
while !workers.is_empty() {
match workers[next].1.unbounded_send(msg) {
Ok(_) => (),
Err(err) => {
let _ = srv.unbounded_send(
ServerCommand::WorkerDied(
workers[next].0, info.clone()));
msg = err.into_inner();
workers.swap_remove(next);
if workers.is_empty() {
error!("No workers");
thread::sleep(sleep);
break
} else if workers.len() <= next {
next = 0;
}
continue
}
}
next = (next + 1) % workers.len();
break
}
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
break,
Err(ref e) if connection_error(e) =>
continue,
Err(e) => {
error!("Error accepting connection: {}", e);
// sleep after error
thread::sleep(sleep);
break
}
}
}
},
CMD => match rx.try_recv() {
Ok(cmd) => match cmd {
Command::Pause => if let Some(server) = server.take() {
if let Err(err) = poll.deregister(&server) {
error!("Can not deregister server socket {}", err);
} else {
info!("Paused accepting connections on {}", addr);
}
},
Command::Resume => {
let lst = create_tcp_listener(addr, backlog)
.expect("Can not create net::TcpListener");
server = Some(
mio::net::TcpListener::from_std(lst)
.expect("Can not create mio::net::TcpListener"));
if let Some(ref server) = server {
if let Err(err) = poll.register(
server, SRV, mio::Ready::readable(), mio::PollOpt::edge())
{
error!("Can not resume socket accept process: {}", err);
} else {
info!("Accepting connections on {} has been resumed",
addr);
}
}
},
Command::Stop => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return
},
Command::Worker(idx, addr) => {
workers.push((idx, addr));
},
},
Err(err) => match err {
sync_mpsc::TryRecvError::Empty => (),
sync_mpsc::TryRecvError::Disconnected => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return
},
}
},
_ => unreachable!(),
// Start listening for incoming connections
if let Some(ref srv) = server {
if let Err(err) =
poll.register(srv, SRV, mio::Ready::readable(), mio::PollOpt::edge())
{
panic!("Can not register io: {}", err);
}
}
}
});
// Start listening for incoming commands
if let Err(err) = poll.register(
&reg,
CMD,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
panic!("Can not register Registration: {}", err);
}
// Create storage for events
let mut events = mio::Events::with_capacity(128);
// Sleep on error
let sleep = Duration::from_millis(100);
let mut next = 0;
loop {
if let Err(err) = poll.poll(&mut events, None) {
panic!("Poll error: {}", err);
}
for event in events.iter() {
match event.token() {
SRV => if let Some(ref server) = server {
loop {
match server.accept_std() {
Ok((sock, addr)) => {
let mut msg = Conn {
io: sock,
peer: Some(addr),
http2: false,
};
while !workers.is_empty() {
match workers[next].1.unbounded_send(msg) {
Ok(_) => (),
Err(err) => {
let _ = srv.unbounded_send(
ServerCommand::WorkerDied(
workers[next].0,
info.clone(),
),
);
msg = err.into_inner();
workers.swap_remove(next);
if workers.is_empty() {
error!("No workers");
thread::sleep(sleep);
break;
} else if workers.len() <= next {
next = 0;
}
continue;
}
}
next = (next + 1) % workers.len();
break;
}
}
Err(ref e)
if e.kind() == io::ErrorKind::WouldBlock =>
{
break
}
Err(ref e) if connection_error(e) => continue,
Err(e) => {
error!("Error accepting connection: {}", e);
// sleep after error
thread::sleep(sleep);
break;
}
}
}
},
CMD => match rx.try_recv() {
Ok(cmd) => match cmd {
Command::Pause => if let Some(server) = server.take() {
if let Err(err) = poll.deregister(&server) {
error!(
"Can not deregister server socket {}",
err
);
} else {
info!(
"Paused accepting connections on {}",
addr
);
}
},
Command::Resume => {
let lst = create_tcp_listener(addr, backlog)
.expect("Can not create net::TcpListener");
server = Some(
mio::net::TcpListener::from_std(lst).expect(
"Can not create mio::net::TcpListener",
),
);
if let Some(ref server) = server {
if let Err(err) = poll.register(
server,
SRV,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
error!("Can not resume socket accept process: {}", err);
} else {
info!("Accepting connections on {} has been resumed",
addr);
}
}
}
Command::Stop => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return;
}
Command::Worker(idx, addr) => {
workers.push((idx, addr));
}
},
Err(err) => match err {
sync_mpsc::TryRecvError::Empty => (),
sync_mpsc::TryRecvError::Disconnected => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return;
}
},
},
_ => unreachable!(),
}
}
}
});
(readiness, tx)
}
fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::TcpListener> {
fn create_tcp_listener(
addr: net::SocketAddr, backlog: i32
) -> io::Result<net::TcpListener> {
let builder = match addr {
net::SocketAddr::V4(_) => TcpBuilder::new_v4()?,
net::SocketAddr::V6(_) => TcpBuilder::new_v6()?,
@ -840,7 +934,7 @@ fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::T
/// The timeout is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE. Otherwise, could enter into tight loop.
fn connection_error(e: &io::Error) -> bool {
e.kind() == io::ErrorKind::ConnectionRefused ||
e.kind() == io::ErrorKind::ConnectionAborted ||
e.kind() == io::ErrorKind::ConnectionReset
e.kind() == io::ErrorKind::ConnectionRefused
|| e.kind() == io::ErrorKind::ConnectionAborted
|| e.kind() == io::ErrorKind::ConnectionReset
}

View File

@ -1,14 +1,15 @@
use std::io;
use bytes::{BytesMut, BufMut};
use bytes::{BufMut, BytesMut};
use futures::{Async, Poll};
use std::io;
use super::IoStream;
const LW_BUFFER_SIZE: usize = 4096;
const HW_BUFFER_SIZE: usize = 32_768;
pub fn read_from_io<T: IoStream>(io: &mut T, buf: &mut BytesMut) -> Poll<usize, io::Error> {
pub fn read_from_io<T: IoStream>(
io: &mut T, buf: &mut BytesMut
) -> Poll<usize, io::Error> {
unsafe {
if buf.remaining_mut() < LW_BUFFER_SIZE {
buf.reserve(HW_BUFFER_SIZE);
@ -17,7 +18,7 @@ pub fn read_from_io<T: IoStream>(io: &mut T, buf: &mut BytesMut) -> Poll<usize,
Ok(n) => {
buf.advance_mut(n);
Ok(Async::Ready(n))
},
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
Ok(Async::NotReady)

View File

@ -1,31 +1,30 @@
use std::{net, time};
use std::rc::Rc;
use futures::Future;
use futures::unsync::oneshot;
use net2::TcpStreamExt;
use std::rc::Rc;
use std::{net, time};
use tokio_core::net::TcpStream;
use tokio_core::reactor::Handle;
use net2::TcpStreamExt;
#[cfg(any(feature="tls", feature="alpn"))]
#[cfg(any(feature = "tls", feature = "alpn"))]
use futures::future;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use native_tls::TlsAcceptor;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use tokio_tls::TlsAcceptorExt;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use openssl::ssl::SslAcceptor;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use tokio_openssl::SslAcceptorExt;
use actix::*;
use actix::msgs::StopArbiter;
use actix::*;
use server::{HttpHandler, KeepAlive};
use server::channel::HttpChannel;
use server::settings::WorkerSettings;
use server::{HttpHandler, KeepAlive};
#[derive(Message)]
pub(crate) struct Conn<T> {
@ -46,9 +45,12 @@ impl Message for StopWorker {
/// Http worker
///
/// Worker accepts Socket objects via unbounded channel and start requests processing.
pub(crate)
struct Worker<H> where H: HttpHandler + 'static {
/// Worker accepts Socket objects via unbounded channel and start requests
/// processing.
pub(crate) struct Worker<H>
where
H: HttpHandler + 'static,
{
settings: Rc<WorkerSettings<H>>,
hnd: Handle,
handler: StreamHandlerType,
@ -56,10 +58,9 @@ struct Worker<H> where H: HttpHandler + 'static {
}
impl<H: HttpHandler + 'static> Worker<H> {
pub(crate) fn new(h: Vec<H>, handler: StreamHandlerType, keep_alive: KeepAlive)
-> Worker<H>
{
pub(crate) fn new(
h: Vec<H>, handler: StreamHandlerType, keep_alive: KeepAlive
) -> Worker<H> {
let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive {
Some(time::Duration::new(val as u64, 0))
} else {
@ -76,11 +77,14 @@ impl<H: HttpHandler + 'static> Worker<H> {
fn update_time(&self, ctx: &mut Context<Self>) {
self.settings.update_date();
ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_time(ctx));
ctx.run_later(time::Duration::new(1, 0), |slf, ctx| {
slf.update_time(ctx)
});
}
fn shutdown_timeout(&self, ctx: &mut Context<Self>,
tx: oneshot::Sender<bool>, dur: time::Duration) {
fn shutdown_timeout(
&self, ctx: &mut Context<Self>, tx: oneshot::Sender<bool>, dur: time::Duration
) {
// sleep for 1 second and then check again
ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| {
let num = slf.settings.num_channels();
@ -99,7 +103,10 @@ impl<H: HttpHandler + 'static> Worker<H> {
}
}
impl<H: 'static> Actor for Worker<H> where H: HttpHandler + 'static {
impl<H: 'static> Actor for Worker<H>
where
H: HttpHandler + 'static,
{
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
@ -108,22 +115,24 @@ impl<H: 'static> Actor for Worker<H> where H: HttpHandler + 'static {
}
impl<H> Handler<Conn<net::TcpStream>> for Worker<H>
where H: HttpHandler + 'static,
where
H: HttpHandler + 'static,
{
type Result = ();
fn handle(&mut self, msg: Conn<net::TcpStream>, _: &mut Context<Self>)
{
fn handle(&mut self, msg: Conn<net::TcpStream>, _: &mut Context<Self>) {
if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() {
error!("Can not set socket keep-alive option");
}
self.handler.handle(Rc::clone(&self.settings), &self.hnd, msg);
self.handler
.handle(Rc::clone(&self.settings), &self.hnd, msg);
}
}
/// `StopWorker` message handler
impl<H> Handler<StopWorker> for Worker<H>
where H: HttpHandler + 'static,
where
H: HttpHandler + 'static,
{
type Result = Response<bool, ()>;
@ -148,17 +157,16 @@ impl<H> Handler<StopWorker> for Worker<H>
#[derive(Clone)]
pub(crate) enum StreamHandlerType {
Normal,
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
Tls(TlsAcceptor),
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
Alpn(SslAcceptor),
}
impl StreamHandlerType {
fn handle<H: HttpHandler>(&mut self,
h: Rc<WorkerSettings<H>>,
hnd: &Handle, msg: Conn<net::TcpStream>) {
fn handle<H: HttpHandler>(
&mut self, h: Rc<WorkerSettings<H>>, hnd: &Handle, msg: Conn<net::TcpStream>
) {
match *self {
StreamHandlerType::Normal => {
let _ = msg.io.set_nodelay(true);
@ -167,7 +175,7 @@ impl StreamHandlerType {
hnd.spawn(HttpChannel::new(h, io, msg.peer, msg.http2));
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
StreamHandlerType::Tls(ref acceptor) => {
let Conn { io, peer, http2 } = msg;
let _ = io.set_nodelay(true);
@ -177,16 +185,21 @@ impl StreamHandlerType {
hnd.spawn(
TlsAcceptorExt::accept_async(acceptor, io).then(move |res| {
match res {
Ok(io) => Arbiter::handle().spawn(
HttpChannel::new(h, io, peer, http2)),
Err(err) =>
trace!("Error during handling tls connection: {}", err),
Ok(io) => Arbiter::handle().spawn(HttpChannel::new(
h,
io,
peer,
http2,
)),
Err(err) => {
trace!("Error during handling tls connection: {}", err)
}
};
future::result(Ok(()))
})
}),
);
}
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
StreamHandlerType::Alpn(ref acceptor) => {
let Conn { io, peer, .. } = msg;
let _ = io.set_nodelay(true);
@ -197,20 +210,26 @@ impl StreamHandlerType {
SslAcceptorExt::accept_async(acceptor, io).then(move |res| {
match res {
Ok(io) => {
let http2 = if let Some(p) = io.get_ref().ssl().selected_alpn_protocol()
let http2 = if let Some(p) =
io.get_ref().ssl().selected_alpn_protocol()
{
p.len() == 2 && &p == b"h2"
} else {
false
};
Arbiter::handle().spawn(
HttpChannel::new(h, io, peer, http2));
},
Err(err) =>
trace!("Error during handling tls connection: {}", err),
Arbiter::handle().spawn(HttpChannel::new(
h,
io,
peer,
http2,
));
}
Err(err) => {
trace!("Error during handling tls connection: {}", err)
}
};
future::result(Ok(()))
})
}),
);
}
}

View File

@ -1,37 +1,37 @@
//! Various helpers for Actix applications to use during testing.
use std::{net, thread};
use std::rc::Rc;
use std::sync::mpsc;
use std::str::FromStr;
use std::sync::mpsc;
use std::{net, thread};
use actix::{Actor, Arbiter, Addr, Syn, System, SystemRunner, Unsync, msgs};
use actix::{msgs, Actor, Addr, Arbiter, Syn, System, SystemRunner, Unsync};
use cookie::Cookie;
use http::{Uri, Method, Version, HeaderMap, HttpTryFrom};
use http::header::HeaderName;
use futures::Future;
use http::header::HeaderName;
use http::{HeaderMap, HttpTryFrom, Method, Uri, Version};
use net2::TcpBuilder;
use tokio_core::net::TcpListener;
use tokio_core::reactor::Core;
use net2::TcpBuilder;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
use openssl::ssl::SslAcceptor;
use ws;
use body::Binary;
use error::Error;
use header::{Header, IntoHeaderValue};
use handler::{Handler, Responder, ReplyItem};
use middleware::Middleware;
use application::{App, HttpApplication};
use param::Params;
use router::Router;
use payload::Payload;
use resource::ResourceHandler;
use body::Binary;
use client::{ClientConnector, ClientRequest, ClientRequestBuilder};
use error::Error;
use handler::{Handler, ReplyItem, Responder};
use header::{Header, IntoHeaderValue};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::Middleware;
use param::Params;
use payload::Payload;
use resource::ResourceHandler;
use router::Router;
use server::{HttpServer, IntoHttpHandler, ServerSettings};
use client::{ClientRequest, ClientRequestBuilder, ClientConnector};
use ws;
/// The `TestServer` type.
///
@ -69,20 +69,20 @@ pub struct TestServer {
}
impl TestServer {
/// Start new test server
///
/// This method accepts configuration method. You can add
/// middlewares or set handlers for test application.
pub fn new<F>(config: F) -> Self
where F: Sync + Send + 'static + Fn(&mut TestApp<()>)
where
F: Sync + Send + 'static + Fn(&mut TestApp<()>),
{
TestServerBuilder::new(||()).start(config)
TestServerBuilder::new(|| ()).start(config)
}
/// Create test server builder
pub fn build() -> TestServerBuilder<()> {
TestServerBuilder::new(||())
TestServerBuilder::new(|| ())
}
/// Create test server builder with specific state factory
@ -91,17 +91,19 @@ impl TestServer {
/// Also it can be used for external dependecy initialization,
/// like creating sync actors for diesel integration.
pub fn build_with_state<F, S>(state: F) -> TestServerBuilder<S>
where F: Fn() -> S + Sync + Send + 'static,
S: 'static,
where
F: Fn() -> S + Sync + Send + 'static,
S: 'static,
{
TestServerBuilder::new(state)
}
/// Start new test server with application factory
pub fn with_factory<F, U, H>(factory: F) -> Self
where F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item=H> + 'static,
H: IntoHttpHandler + 'static,
where
F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item = H> + 'static,
H: IntoHttpHandler + 'static,
{
let (tx, rx) = mpsc::channel();
@ -110,8 +112,8 @@ impl TestServer {
let sys = System::new("actix-test-server");
let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = tcp.local_addr().unwrap();
let tcp = TcpListener::from_listener(
tcp, &local_addr, Arbiter::handle()).unwrap();
let tcp =
TcpListener::from_listener(tcp, &local_addr, Arbiter::handle()).unwrap();
HttpServer::new(factory)
.disable_signals()
@ -134,15 +136,15 @@ impl TestServer {
}
fn get_conn() -> Addr<Unsync, ClientConnector> {
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
{
use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode};
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE);
ClientConnector::with_connector(builder.build()).start()
}
#[cfg(not(feature="alpn"))]
#[cfg(not(feature = "alpn"))]
{
ClientConnector::default().start()
}
@ -166,9 +168,19 @@ impl TestServer {
/// Construct test server url
pub fn url(&self, uri: &str) -> String {
if uri.starts_with('/') {
format!("{}://{}{}", if self.ssl {"https"} else {"http"}, self.addr, uri)
format!(
"{}://{}{}",
if self.ssl { "https" } else { "http" },
self.addr,
uri
)
} else {
format!("{}://{}/{}", if self.ssl {"https"} else {"http"}, self.addr, uri)
format!(
"{}://{}/{}",
if self.ssl { "https" } else { "http" },
self.addr,
uri
)
}
}
@ -182,16 +194,20 @@ impl TestServer {
/// Execute future on current core
pub fn execute<F, I, E>(&mut self, fut: F) -> Result<I, E>
where F: Future<Item=I, Error=E>
where
F: Future<Item = I, Error = E>,
{
self.system.run_until_complete(fut)
}
/// Connect to websocket server
pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
pub fn ws(
&mut self
) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
let url = self.url("/");
self.system.run_until_complete(
ws::Client::with_connector(url, self.conn.clone()).connect())
ws::Client::with_connector(url, self.conn.clone()).connect(),
)
}
/// Create `GET` request
@ -231,23 +247,23 @@ impl Drop for TestServer {
/// builder-like pattern.
pub struct TestServerBuilder<S> {
state: Box<Fn() -> S + Sync + Send + 'static>,
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
ssl: Option<SslAcceptor>,
}
impl<S: 'static> TestServerBuilder<S> {
pub fn new<F>(state: F) -> TestServerBuilder<S>
where F: Fn() -> S + Sync + Send + 'static
where
F: Fn() -> S + Sync + Send + 'static,
{
TestServerBuilder {
state: Box::new(state),
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
ssl: None,
}
}
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
/// Create ssl server
pub fn ssl(mut self, ssl: SslAcceptor) -> Self {
self.ssl = Some(ssl);
@ -257,13 +273,14 @@ impl<S: 'static> TestServerBuilder<S> {
#[allow(unused_mut)]
/// Configure test application and run test server
pub fn start<F>(mut self, config: F) -> TestServer
where F: Sync + Send + 'static + Fn(&mut TestApp<S>),
where
F: Sync + Send + 'static + Fn(&mut TestApp<S>),
{
let (tx, rx) = mpsc::channel();
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
let ssl = self.ssl.is_some();
#[cfg(not(feature="alpn"))]
#[cfg(not(feature = "alpn"))]
let ssl = false;
// run server in separate thread
@ -272,38 +289,38 @@ impl<S: 'static> TestServerBuilder<S> {
let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = tcp.local_addr().unwrap();
let tcp = TcpListener::from_listener(
tcp, &local_addr, Arbiter::handle()).unwrap();
let tcp =
TcpListener::from_listener(tcp, &local_addr, Arbiter::handle()).unwrap();
let state = self.state;
let srv = HttpServer::new(move || {
let mut app = TestApp::new(state());
config(&mut app);
vec![app]})
.disable_signals();
vec![app]
}).disable_signals();
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
{
use std::io;
use futures::Stream;
use std::io;
use tokio_openssl::SslAcceptorExt;
let ssl = self.ssl.take();
if let Some(ssl) = ssl {
srv.start_incoming(
tcp.incoming()
.and_then(move |(sock, addr)| {
ssl.accept_async(sock)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.map(move |s| (s, addr))
}),
false);
tcp.incoming().and_then(move |(sock, addr)| {
ssl.accept_async(sock)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.map(move |s| (s, addr))
}),
false,
);
} else {
srv.start_incoming(tcp.incoming(), false);
}
}
#[cfg(not(feature="alpn"))]
#[cfg(not(feature = "alpn"))]
{
srv.start_incoming(tcp.incoming(), false);
}
@ -326,24 +343,30 @@ impl<S: 'static> TestServerBuilder<S> {
}
/// Test application helper for testing request handlers.
pub struct TestApp<S=()> {
pub struct TestApp<S = ()> {
app: Option<App<S>>,
}
impl<S: 'static> TestApp<S> {
fn new(state: S) -> TestApp<S> {
let app = App::with_state(state);
TestApp{app: Some(app)}
TestApp { app: Some(app) }
}
/// Register handler for "/"
pub fn handler<H: Handler<S>>(&mut self, handler: H) {
self.app = Some(self.app.take().unwrap().resource("/", |r| r.h(handler)));
self.app = Some(
self.app
.take()
.unwrap()
.resource("/", |r| r.h(handler)),
);
}
/// Register middleware
pub fn middleware<T>(&mut self, mw: T) -> &mut TestApp<S>
where T: Middleware<S> + 'static
where
T: Middleware<S> + 'static,
{
self.app = Some(self.app.take().unwrap().middleware(mw));
self
@ -352,7 +375,8 @@ impl<S: 'static> TestApp<S> {
/// Register resource. This method is similar
/// to `App::resource()` method.
pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut TestApp<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static
where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{
self.app = Some(self.app.take().unwrap().resource(path, f));
self
@ -419,7 +443,6 @@ pub struct TestRequest<S> {
}
impl Default for TestRequest<()> {
fn default() -> TestRequest<()> {
TestRequest {
state: (),
@ -435,28 +458,27 @@ impl Default for TestRequest<()> {
}
impl TestRequest<()> {
/// Create TestRequest and set request uri
pub fn with_uri(path: &str) -> TestRequest<()> {
TestRequest::default().uri(path)
}
/// Create TestRequest and set header
pub fn with_hdr<H: Header>(hdr: H) -> TestRequest<()>
{
pub fn with_hdr<H: Header>(hdr: H) -> TestRequest<()> {
TestRequest::default().set(hdr)
}
/// Create TestRequest and set header
pub fn with_header<K, V>(key: K, value: V) -> TestRequest<()>
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue,
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
TestRequest::default().header(key, value)
}
}
impl<S> TestRequest<S> {
/// Start HttpRequest build process with application state
pub fn with_state(state: S) -> TestRequest<S> {
TestRequest {
@ -490,23 +512,24 @@ impl<S> TestRequest<S> {
}
/// Set a header
pub fn set<H: Header>(mut self, hdr: H) -> Self
{
pub fn set<H: Header>(mut self, hdr: H) -> Self {
if let Ok(value) = hdr.try_into() {
self.headers.append(H::name(), value);
return self
return self;
}
panic!("Can not set header");
}
/// Set a header
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Ok(key) = HeaderName::try_from(key) {
if let Ok(value) = value.try_into() {
self.headers.append(key, value);
return self
return self;
}
}
panic!("Can not create header");
@ -529,7 +552,16 @@ impl<S> TestRequest<S> {
/// Complete request creation and generate `HttpRequest` instance
pub fn finish(self) -> HttpRequest<S> {
let TestRequest { state, method, uri, version, headers, params, cookies, payload } = self;
let TestRequest {
state,
method,
uri,
version,
headers,
params,
cookies,
payload,
} = self;
let req = HttpRequest::new(method, uri, version, headers, payload);
req.as_mut().cookies = cookies;
req.as_mut().params = params;
@ -540,8 +572,16 @@ impl<S> TestRequest<S> {
#[cfg(test)]
/// Complete request creation and generate `HttpRequest` instance
pub(crate) fn finish_with_router(self, router: Router) -> HttpRequest<S> {
let TestRequest { state, method, uri,
version, headers, params, cookies, payload } = self;
let TestRequest {
state,
method,
uri,
version,
headers,
params,
cookies,
payload,
} = self;
let req = HttpRequest::new(method, uri, version, headers, payload);
req.as_mut().cookies = cookies;
@ -553,18 +593,16 @@ impl<S> TestRequest<S> {
/// with generated request.
///
/// This method panics is handler returns actor or async result.
pub fn run<H: Handler<S>>(self, mut h: H) ->
Result<HttpResponse, <<H as Handler<S>>::Result as Responder>::Error>
{
pub fn run<H: Handler<S>>(
self, mut h: H
) -> Result<HttpResponse, <<H as Handler<S>>::Result as Responder>::Error> {
let req = self.finish();
let resp = h.handle(req.clone());
match resp.respond_to(req.drop_state()) {
Ok(resp) => {
match resp.into().into() {
ReplyItem::Message(resp) => Ok(resp),
ReplyItem::Future(_) => panic!("Async handler is not supported."),
}
Ok(resp) => match resp.into().into() {
ReplyItem::Message(resp) => Ok(resp),
ReplyItem::Future(_) => panic!("Async handler is not supported."),
},
Err(err) => Err(err),
}
@ -575,24 +613,23 @@ impl<S> TestRequest<S> {
///
/// This method panics is handler returns actor.
pub fn run_async<H, R, F, E>(self, h: H) -> Result<HttpResponse, E>
where H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static,
R: Responder<Error=E> + 'static,
E: Into<Error> + 'static
where
H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item = R, Error = E> + 'static,
R: Responder<Error = E> + 'static,
E: Into<Error> + 'static,
{
let req = self.finish();
let fut = h(req.clone());
let mut core = Core::new().unwrap();
match core.run(fut) {
Ok(r) => {
match r.respond_to(req.drop_state()) {
Ok(reply) => match reply.into().into() {
ReplyItem::Message(resp) => Ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => Err(e),
}
Ok(r) => match r.respond_to(req.drop_state()) {
Ok(reply) => match reply.into().into() {
ReplyItem::Message(resp) => Ok(resp),
_ => panic!("Nested async replies are not supported"),
},
Err(e) => Err(e),
},
Err(err) => Err(err),
}

View File

@ -1,33 +1,37 @@
use std::rc::Rc;
use futures::{Async, Future, Poll};
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use futures::{Async, Future, Poll};
use std::rc::Rc;
use error::Error;
use handler::{Handler, FromRequest, Reply, ReplyItem, Responder};
use handler::{FromRequest, Handler, Reply, ReplyItem, Responder};
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
pub struct ExtractorConfig<S: 'static, T: FromRequest<S>> {
cfg: Rc<UnsafeCell<T::Config>>
cfg: Rc<UnsafeCell<T::Config>>,
}
impl<S: 'static, T: FromRequest<S>> Default for ExtractorConfig<S, T> {
fn default() -> Self {
ExtractorConfig { cfg: Rc::new(UnsafeCell::new(T::Config::default())) }
ExtractorConfig {
cfg: Rc::new(UnsafeCell::new(T::Config::default())),
}
}
}
impl<S: 'static, T: FromRequest<S>> Clone for ExtractorConfig<S, T> {
fn clone(&self) -> Self {
ExtractorConfig { cfg: Rc::clone(&self.cfg) }
ExtractorConfig {
cfg: Rc::clone(&self.cfg),
}
}
}
impl<S: 'static, T: FromRequest<S>> AsRef<T::Config> for ExtractorConfig<S, T> {
fn as_ref(&self) -> &T::Config {
unsafe{&*self.cfg.get()}
unsafe { &*self.cfg.get() }
}
}
@ -35,18 +39,21 @@ impl<S: 'static, T: FromRequest<S>> Deref for ExtractorConfig<S, T> {
type Target = T::Config;
fn deref(&self) -> &T::Config {
unsafe{&*self.cfg.get()}
unsafe { &*self.cfg.get() }
}
}
impl<S: 'static, T: FromRequest<S>> DerefMut for ExtractorConfig<S, T> {
fn deref_mut(&mut self) -> &mut T::Config {
unsafe{&mut *self.cfg.get()}
unsafe { &mut *self.cfg.get() }
}
}
pub struct With<T, S, F, R>
where F: Fn(T) -> R, T: FromRequest<S>, S: 'static,
where
F: Fn(T) -> R,
T: FromRequest<S>,
S: 'static,
{
hnd: Rc<UnsafeCell<F>>,
cfg: ExtractorConfig<S, T>,
@ -54,23 +61,31 @@ pub struct With<T, S, F, R>
}
impl<T, S, F, R> With<T, S, F, R>
where F: Fn(T) -> R, T: FromRequest<S>, S: 'static,
where
F: Fn(T) -> R,
T: FromRequest<S>,
S: 'static,
{
pub fn new(f: F, cfg: ExtractorConfig<S, T>) -> Self {
With{cfg, hnd: Rc::new(UnsafeCell::new(f)), _s: PhantomData}
With {
cfg,
hnd: Rc::new(UnsafeCell::new(f)),
_s: PhantomData,
}
}
}
impl<T, S, F, R> Handler<S> for With<T, S, F, R>
where F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T) -> R + 'static,
R: Responder + 'static,
T: FromRequest<S> + 'static,
S: 'static,
{
type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut{
let mut fut = WithHandlerFut {
req,
started: false,
hnd: Rc::clone(&self.hnd),
@ -88,31 +103,33 @@ impl<T, S, F, R> Handler<S> for With<T, S, F, R>
}
struct WithHandlerFut<T, S, F, R>
where F: Fn(T) -> R,
R: Responder,
T: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T) -> R,
R: Responder,
T: FromRequest<S> + 'static,
S: 'static,
{
started: bool,
hnd: Rc<UnsafeCell<F>>,
cfg: ExtractorConfig<S, T>,
req: HttpRequest<S>,
fut1: Option<Box<Future<Item=T, Error=Error>>>,
fut2: Option<Box<Future<Item=HttpResponse, Error=Error>>>,
fut1: Option<Box<Future<Item = T, Error = Error>>>,
fut2: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
}
impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
where F: Fn(T) -> R,
R: Responder + 'static,
T: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T) -> R,
R: Responder + 'static,
T: FromRequest<S> + 'static,
S: 'static,
{
type Item = HttpResponse;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut2 {
return fut.poll()
return fut.poll();
}
let item = if !self.started {
@ -122,8 +139,8 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
Ok(Async::Ready(item)) => item,
Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady)
},
return Ok(Async::NotReady);
}
Err(e) => return Err(e),
}
} else {
@ -133,7 +150,7 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
}
};
let hnd: &mut F = unsafe{&mut *self.hnd.get()};
let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(item).respond_to(self.req.drop_state()) {
Ok(item) => item.into(),
Err(e) => return Err(e.into()),
@ -150,8 +167,11 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
}
pub struct With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static, T2: FromRequest<S> + 'static, S: 'static
where
F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{
hnd: Rc<UnsafeCell<F>>,
cfg1: ExtractorConfig<S, T1>,
@ -160,26 +180,36 @@ pub struct With2<T1, T2, S, F, R>
}
impl<T1, T2, S, F, R> With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static, T2: FromRequest<S> + 'static, S: 'static
where
F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{
pub fn new(f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>) -> Self {
With2{hnd: Rc::new(UnsafeCell::new(f)),
cfg1, cfg2, _s: PhantomData}
pub fn new(
f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>
) -> Self {
With2 {
hnd: Rc::new(UnsafeCell::new(f)),
cfg1,
cfg2,
_s: PhantomData,
}
}
}
impl<T1, T2, S, F, R> Handler<S> for With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{
type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut2{
let mut fut = WithHandlerFut2 {
req,
started: false,
hnd: Rc::clone(&self.hnd),
@ -199,11 +229,12 @@ impl<T1, T2, S, F, R> Handler<S> for With2<T1, T2, S, F, R>
}
struct WithHandlerFut2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{
started: bool,
hnd: Rc<UnsafeCell<F>>,
@ -211,24 +242,25 @@ struct WithHandlerFut2<T1, T2, S, F, R>
cfg2: ExtractorConfig<S, T2>,
req: HttpRequest<S>,
item: Option<T1>,
fut1: Option<Box<Future<Item=T1, Error=Error>>>,
fut2: Option<Box<Future<Item=T2, Error=Error>>>,
fut3: Option<Box<Future<Item=HttpResponse, Error=Error>>>,
fut1: Option<Box<Future<Item = T1, Error = Error>>>,
fut2: Option<Box<Future<Item = T2, Error = Error>>>,
fut3: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
}
impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{
type Item = HttpResponse;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut3 {
return fut.poll()
return fut.poll();
}
if !self.started {
@ -236,32 +268,32 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
let mut fut = T1::from_request(&self.req, self.cfg1.as_ref());
match fut.poll() {
Ok(Async::Ready(item1)) => {
let mut fut = T2::from_request(&self.req,self.cfg2.as_ref());
let mut fut = T2::from_request(&self.req, self.cfg2.as_ref());
match fut.poll() {
Ok(Async::Ready(item2)) => {
let hnd: &mut F = unsafe{&mut *self.hnd.get()};
match (*hnd)(item1, item2)
.respond_to(self.req.drop_state())
let hnd: &mut F = unsafe { &mut *self.hnd.get() };
match (*hnd)(item1, item2).respond_to(self.req.drop_state())
{
Ok(item) => match item.into().into() {
ReplyItem::Message(resp) =>
return Ok(Async::Ready(resp)),
ReplyItem::Message(resp) => {
return Ok(Async::Ready(resp))
}
ReplyItem::Future(fut) => {
self.fut3 = Some(fut);
return self.poll()
return self.poll();
}
},
Err(e) => return Err(e.into()),
}
},
}
Ok(Async::NotReady) => {
self.item = Some(item1);
self.fut2 = Some(Box::new(fut));
return Ok(Async::NotReady);
},
}
Err(e) => return Err(e),
}
},
}
Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady);
@ -275,9 +307,11 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
Async::Ready(item) => {
self.item = Some(item);
self.fut1.take();
self.fut2 = Some(Box::new(
T2::from_request(&self.req, self.cfg2.as_ref())));
},
self.fut2 = Some(Box::new(T2::from_request(
&self.req,
self.cfg2.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady),
}
}
@ -287,7 +321,7 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
Async::NotReady => return Ok(Async::NotReady),
};
let hnd: &mut F = unsafe{&mut *self.hnd.get()};
let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(self.item.take().unwrap(), item)
.respond_to(self.req.drop_state())
{
@ -305,11 +339,12 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
}
pub struct With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2, T3) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static,
{
hnd: Rc<UnsafeCell<F>>,
cfg1: ExtractorConfig<S, T1>,
@ -318,33 +353,44 @@ pub struct With3<T1, T2, T3, S, F, R>
_s: PhantomData<S>,
}
impl<T1, T2, T3, S, F, R> With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2, T3) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static,
{
pub fn new(f: F, cfg1: ExtractorConfig<S, T1>,
cfg2: ExtractorConfig<S, T2>, cfg3: ExtractorConfig<S, T3>) -> Self
{
With3{hnd: Rc::new(UnsafeCell::new(f)), cfg1, cfg2, cfg3, _s: PhantomData}
pub fn new(
f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>,
cfg3: ExtractorConfig<S, T3>,
) -> Self {
With3 {
hnd: Rc::new(UnsafeCell::new(f)),
cfg1,
cfg2,
cfg3,
_s: PhantomData,
}
}
}
impl<T1, T2, T3, S, F, R> Handler<S> for With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S>,
T2: FromRequest<S>,
T3: FromRequest<S>,
T1: 'static, T2: 'static, T3: 'static, S: 'static
where
F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S>,
T2: FromRequest<S>,
T3: FromRequest<S>,
T1: 'static,
T2: 'static,
T3: 'static,
S: 'static,
{
type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut3{
let mut fut = WithHandlerFut3 {
req,
hnd: Rc::clone(&self.hnd),
cfg1: self.cfg1.clone(),
@ -367,12 +413,13 @@ impl<T1, T2, T3, S, F, R> Handler<S> for With3<T1, T2, T3, S, F, R>
}
struct WithHandlerFut3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static,
{
hnd: Rc<UnsafeCell<F>>,
req: HttpRequest<S>,
@ -382,26 +429,27 @@ struct WithHandlerFut3<T1, T2, T3, S, F, R>
started: bool,
item1: Option<T1>,
item2: Option<T2>,
fut1: Option<Box<Future<Item=T1, Error=Error>>>,
fut2: Option<Box<Future<Item=T2, Error=Error>>>,
fut3: Option<Box<Future<Item=T3, Error=Error>>>,
fut4: Option<Box<Future<Item=HttpResponse, Error=Error>>>,
fut1: Option<Box<Future<Item = T1, Error = Error>>>,
fut2: Option<Box<Future<Item = T2, Error = Error>>>,
fut3: Option<Box<Future<Item = T3, Error = Error>>>,
fut4: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
}
impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static
where
F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
S: 'static,
{
type Item = HttpResponse;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut4 {
return fut.poll()
return fut.poll();
}
if !self.started {
@ -412,41 +460,43 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
let mut fut = T2::from_request(&self.req, self.cfg2.as_ref());
match fut.poll() {
Ok(Async::Ready(item2)) => {
let mut fut = T3::from_request(&self.req, self.cfg3.as_ref());
let mut fut =
T3::from_request(&self.req, self.cfg3.as_ref());
match fut.poll() {
Ok(Async::Ready(item3)) => {
let hnd: &mut F = unsafe{&mut *self.hnd.get()};
let hnd: &mut F = unsafe { &mut *self.hnd.get() };
match (*hnd)(item1, item2, item3)
.respond_to(self.req.drop_state())
{
Ok(item) => match item.into().into() {
ReplyItem::Message(resp) =>
return Ok(Async::Ready(resp)),
ReplyItem::Message(resp) => {
return Ok(Async::Ready(resp))
}
ReplyItem::Future(fut) => {
self.fut4 = Some(fut);
return self.poll()
return self.poll();
}
},
Err(e) => return Err(e.into()),
}
},
}
Ok(Async::NotReady) => {
self.item1 = Some(item1);
self.item2 = Some(item2);
self.fut3 = Some(Box::new(fut));
return Ok(Async::NotReady);
},
}
Err(e) => return Err(e),
}
},
}
Ok(Async::NotReady) => {
self.item1 = Some(item1);
self.fut2 = Some(Box::new(fut));
return Ok(Async::NotReady);
},
}
Err(e) => return Err(e),
}
},
}
Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady);
@ -460,9 +510,11 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::Ready(item) => {
self.item1 = Some(item);
self.fut1.take();
self.fut2 = Some(Box::new(
T2::from_request(&self.req, self.cfg2.as_ref())));
},
self.fut2 = Some(Box::new(T2::from_request(
&self.req,
self.cfg2.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady),
}
}
@ -472,9 +524,11 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::Ready(item) => {
self.item2 = Some(item);
self.fut2.take();
self.fut3 = Some(Box::new(
T3::from_request(&self.req, self.cfg3.as_ref())));
},
self.fut3 = Some(Box::new(T3::from_request(
&self.req,
self.cfg3.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady),
}
}
@ -484,11 +538,12 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::NotReady => return Ok(Async::NotReady),
};
let hnd: &mut F = unsafe{&mut *self.hnd.get()};
let item = match (*hnd)(self.item1.take().unwrap(),
self.item2.take().unwrap(),
item)
.respond_to(self.req.drop_state())
let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(
self.item1.take().unwrap(),
self.item2.take().unwrap(),
item,
).respond_to(self.req.drop_state())
{
Ok(item) => item.into(),
Err(err) => return Err(err.into()),

View File

@ -1,67 +1,65 @@
//! Http client request
use std::{fmt, io, str};
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::rc::Rc;
use std::time::Duration;
use std::{fmt, io, str};
use base64;
use rand;
use byteorder::{ByteOrder, NetworkEndian};
use bytes::Bytes;
use cookie::Cookie;
use byteorder::{ByteOrder, NetworkEndian};
use http::{HttpTryFrom, StatusCode, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use sha1::Sha1;
use futures::{Async, Future, Poll, Stream};
use futures::unsync::mpsc::{unbounded, UnboundedSender};
use futures::{Async, Future, Poll, Stream};
use http::header::{self, HeaderName, HeaderValue};
use http::{Error as HttpError, HttpTryFrom, StatusCode};
use rand;
use sha1::Sha1;
use actix::prelude::*;
use body::{Body, Binary};
use body::{Binary, Body};
use error::{Error, UrlParseError};
use header::IntoHeaderValue;
use payload::PayloadHelper;
use httpmessage::HttpMessage;
use payload::PayloadHelper;
use client::{ClientRequest, ClientRequestBuilder, ClientResponse,
ClientConnector, SendRequest, SendRequestError,
HttpResponseParserError};
use client::{ClientConnector, ClientRequest, ClientRequestBuilder, ClientResponse,
HttpResponseParserError, SendRequest, SendRequestError};
use super::{Message, ProtocolError};
use super::frame::Frame;
use super::proto::{CloseCode, OpCode};
use super::{Message, ProtocolError};
/// Websocket client error
#[derive(Fail, Debug)]
pub enum ClientError {
#[fail(display="Invalid url")]
#[fail(display = "Invalid url")]
InvalidUrl,
#[fail(display="Invalid response status")]
#[fail(display = "Invalid response status")]
InvalidResponseStatus(StatusCode),
#[fail(display="Invalid upgrade header")]
#[fail(display = "Invalid upgrade header")]
InvalidUpgradeHeader,
#[fail(display="Invalid connection header")]
#[fail(display = "Invalid connection header")]
InvalidConnectionHeader(HeaderValue),
#[fail(display="Missing CONNECTION header")]
#[fail(display = "Missing CONNECTION header")]
MissingConnectionHeader,
#[fail(display="Missing SEC-WEBSOCKET-ACCEPT header")]
#[fail(display = "Missing SEC-WEBSOCKET-ACCEPT header")]
MissingWebSocketAcceptHeader,
#[fail(display="Invalid challenge response")]
#[fail(display = "Invalid challenge response")]
InvalidChallengeResponse(String, HeaderValue),
#[fail(display="Http parsing error")]
#[fail(display = "Http parsing error")]
Http(Error),
#[fail(display="Url parsing error")]
#[fail(display = "Url parsing error")]
Url(UrlParseError),
#[fail(display="Response parsing error")]
#[fail(display = "Response parsing error")]
ResponseParseError(HttpResponseParserError),
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
SendRequest(SendRequestError),
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Protocol(#[cause] ProtocolError),
#[fail(display="{}", _0)]
#[fail(display = "{}", _0)]
Io(io::Error),
#[fail(display="Disconnected")]
#[fail(display = "Disconnected")]
Disconnected,
}
@ -117,14 +115,15 @@ pub struct Client {
}
impl Client {
/// Create new websocket connection
pub fn new<S: AsRef<str>>(uri: S) -> Client {
Client::with_connector(uri, ClientConnector::from_registry())
}
/// Create new websocket connection with custom `ClientConnector`
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> Client {
pub fn with_connector<S: AsRef<str>>(
uri: S, conn: Addr<Unsync, ClientConnector>
) -> Client {
let mut cl = Client {
request: ClientRequest::build(),
err: None,
@ -140,11 +139,13 @@ impl Client {
/// Set supported websocket protocols
pub fn protocols<U, V>(mut self, protos: U) -> Self
where U: IntoIterator<Item=V> + 'static,
V: AsRef<str>
where
U: IntoIterator<Item = V> + 'static,
V: AsRef<str>,
{
let mut protos = protos.into_iter()
.fold(String::new(), |acc, s| {acc + s.as_ref() + ","});
let mut protos = protos
.into_iter()
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
protos.pop();
self.protocols = Some(protos);
self
@ -158,7 +159,8 @@ impl Client {
/// Set request Origin
pub fn origin<V>(mut self, origin: V) -> Self
where HeaderValue: HttpTryFrom<V>
where
HeaderValue: HttpTryFrom<V>,
{
match HeaderValue::try_from(origin) {
Ok(value) => self.origin = Some(value),
@ -185,7 +187,9 @@ impl Client {
/// Set request header
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
self.request.header(key, value);
self
@ -204,8 +208,7 @@ impl Client {
pub fn connect(&mut self) -> ClientHandshake {
if let Some(e) = self.err.take() {
ClientHandshake::error(e)
}
else if let Some(e) = self.http_err.take() {
} else if let Some(e) = self.http_err.take() {
ClientHandshake::error(Error::from(e).into())
} else {
// origin
@ -216,11 +219,13 @@ impl Client {
self.request.upgrade();
self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header(header::SEC_WEBSOCKET_VERSION, "13");
self.request
.set_header(header::SEC_WEBSOCKET_VERSION, "13");
self.request.with_connector(self.conn.clone());
if let Some(protocols) = self.protocols.take() {
self.request.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str());
self.request
.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str());
}
let request = match self.request.finish() {
Ok(req) => req,
@ -228,14 +233,16 @@ impl Client {
};
if request.uri().host().is_none() {
return ClientHandshake::error(ClientError::InvalidUrl)
return ClientHandshake::error(ClientError::InvalidUrl);
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return ClientHandshake::error(ClientError::InvalidUrl)
if scheme != "http" && scheme != "https" && scheme != "ws"
&& scheme != "wss"
{
return ClientHandshake::error(ClientError::InvalidUrl);
}
} else {
return ClientHandshake::error(ClientError::InvalidUrl)
return ClientHandshake::error(ClientError::InvalidUrl);
}
// start handshake
@ -263,8 +270,7 @@ pub struct ClientHandshake {
}
impl ClientHandshake {
fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake
{
fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake {
// Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455)
@ -273,12 +279,13 @@ impl ClientHandshake {
request.headers_mut().insert(
header::SEC_WEBSOCKET_KEY,
HeaderValue::try_from(key.as_str()).unwrap());
HeaderValue::try_from(key.as_str()).unwrap(),
);
let (tx, rx) = unbounded();
request.set_body(Body::Streaming(
Box::new(rx.map_err(|_| io::Error::new(
io::ErrorKind::Other, "disconnected").into()))));
request.set_body(Body::Streaming(Box::new(rx.map_err(|_| {
io::Error::new(io::ErrorKind::Other, "disconnected").into()
}))));
ClientHandshake {
key,
@ -329,20 +336,20 @@ impl Future for ClientHandshake {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() {
return Err(err)
return Err(err);
}
let resp = match self.request.as_mut().unwrap().poll()? {
Async::Ready(response) => {
self.request.take();
response
},
Async::NotReady => return Ok(Async::NotReady)
}
Async::NotReady => return Ok(Async::NotReady),
};
// verify response
if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(ClientError::InvalidResponseStatus(resp.status()))
return Err(ClientError::InvalidResponseStatus(resp.status()));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
@ -356,26 +363,25 @@ impl Future for ClientHandshake {
};
if !has_hdr {
trace!("Invalid upgrade header");
return Err(ClientError::InvalidUpgradeHeader)
return Err(ClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = resp.headers().get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_lowercase().contains("upgrade") {
trace!("Invalid connection header: {}", s);
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
trace!("Invalid connection header: {:?}", conn);
return Err(ClientError::InvalidConnectionHeader(conn.clone()))
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
trace!("Missing connection header");
return Err(ClientError::MissingConnectionHeader)
return Err(ClientError::MissingConnectionHeader);
}
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT)
{
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) {
// field is constructed by concatenating /key/
// with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@ -386,12 +392,17 @@ impl Future for ClientHandshake {
if key.as_bytes() != encoded.as_bytes() {
trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded, key);
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
encoded,
key
);
return Err(ClientError::InvalidChallengeResponse(
encoded,
key.clone(),
));
}
} else {
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(ClientError::MissingWebSocketAcceptHeader)
return Err(ClientError::MissingWebSocketAcceptHeader);
};
let inner = Inner {
@ -401,13 +412,16 @@ impl Future for ClientHandshake {
};
let inner = Rc::new(UnsafeCell::new(inner));
Ok(Async::Ready(
(ClientReader{inner: Rc::clone(&inner), max_size: self.max_size},
ClientWriter{inner})))
Ok(Async::Ready((
ClientReader {
inner: Rc::clone(&inner),
max_size: self.max_size,
},
ClientWriter { inner },
)))
}
}
pub struct ClientReader {
inner: Rc<UnsafeCell<Inner>>,
max_size: usize,
@ -422,7 +436,7 @@ impl fmt::Debug for ClientReader {
impl ClientReader {
#[inline]
fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() }
unsafe { &mut *self.inner.get() }
}
}
@ -434,7 +448,7 @@ impl Stream for ClientReader {
let max_size = self.max_size;
let inner = self.as_mut();
if inner.closed {
return Ok(Async::Ready(None))
return Ok(Async::Ready(None));
}
// read
@ -447,31 +461,29 @@ impl Stream for ClientReader {
OpCode::Continue => {
inner.closed = true;
Err(ProtocolError::NoContinuation)
},
}
OpCode::Bad => {
inner.closed = true;
Err(ProtocolError::BadOpCode)
},
}
OpCode::Close => {
inner.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
Ok(Async::Ready(Some(Message::Close(CloseCode::from(code)))))
},
OpCode::Ping =>
Ok(Async::Ready(Some(
Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Pong =>
Ok(Async::Ready(Some(
Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Binary =>
Ok(Async::Ready(Some(Message::Binary(payload)))),
Ok(Async::Ready(Some(Message::Close(CloseCode::from(
code,
)))))
}
OpCode::Ping => Ok(Async::Ready(Some(Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Pong => Ok(Async::Ready(Some(Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => {
let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) =>
Ok(Async::Ready(Some(Message::Text(s)))),
Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => {
inner.closed = true;
Err(ProtocolError::BadEncoding)
@ -491,18 +503,17 @@ impl Stream for ClientReader {
}
pub struct ClientWriter {
inner: Rc<UnsafeCell<Inner>>
inner: Rc<UnsafeCell<Inner>>,
}
impl ClientWriter {
#[inline]
fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() }
unsafe { &mut *self.inner.get() }
}
}
impl ClientWriter {
/// Write payload
#[inline]
fn write(&mut self, mut data: Binary) {
@ -528,13 +539,23 @@ impl ClientWriter {
/// Send ping frame
#[inline]
pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true, true));
self.write(Frame::message(
Vec::from(message),
OpCode::Ping,
true,
true,
));
}
/// Send pong frame
#[inline]
pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true, true));
self.write(Frame::message(
Vec::from(message),
OpCode::Pong,
true,
true,
));
}
/// Send close frame

View File

@ -1,25 +1,26 @@
use std::mem;
use futures::{Async, Poll};
use futures::sync::oneshot::Sender;
use futures::unsync::oneshot;
use futures::{Async, Poll};
use smallvec::SmallVec;
use std::mem;
use actix::{Actor, ActorState, ActorContext, AsyncContext,
Addr, Handler, Message, Syn, Unsync, SpawnHandle};
use actix::dev::{ContextImpl, SyncEnvelope, ToEnvelope};
use actix::fut::ActorFuture;
use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope};
use actix::{Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message,
SpawnHandle, Syn, Unsync};
use body::{Body, Binary};
use body::{Binary, Body};
use context::{ActorHttpContext, Drain, Frame as ContextFrame};
use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest;
use context::{Frame as ContextFrame, ActorHttpContext, Drain};
use ws::frame::Frame;
use ws::proto::{OpCode, CloseCode};
use ws::proto::{CloseCode, OpCode};
/// Execution context for `WebSockets` actors
pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A, S>>,
pub struct WebsocketContext<A, S = ()>
where
A: Actor<Context = WebsocketContext<A, S>>,
{
inner: ContextImpl<A>,
stream: Option<SmallVec<[ContextFrame; 4]>>,
@ -27,7 +28,9 @@ pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A,
disconnected: bool,
}
impl<A, S> ActorContext for WebsocketContext<A, S> where A: Actor<Context=Self>
impl<A, S> ActorContext for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
fn stop(&mut self) {
self.inner.stop();
@ -40,16 +43,20 @@ impl<A, S> ActorContext for WebsocketContext<A, S> where A: Actor<Context=Self>
}
}
impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Self>
impl<A, S> AsyncContext<A> for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
fn spawn<F>(&mut self, fut: F) -> SpawnHandle
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{
self.inner.spawn(fut)
}
fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static
where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{
self.inner.wait(fut)
}
@ -57,8 +64,8 @@ impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Sel
#[doc(hidden)]
#[inline]
fn waiting(&self) -> bool {
self.inner.waiting() || self.inner.state() == ActorState::Stopping ||
self.inner.state() == ActorState::Stopped
self.inner.waiting() || self.inner.state() == ActorState::Stopping
|| self.inner.state() == ActorState::Stopped
}
fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
@ -78,8 +85,10 @@ impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Sel
}
}
impl<A, S: 'static> WebsocketContext<A, S> where A: Actor<Context=Self> {
impl<A, S: 'static> WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
#[inline]
pub fn new(req: HttpRequest<S>, actor: A) -> WebsocketContext<A, S> {
WebsocketContext::from_request(req).actor(actor)
@ -101,8 +110,10 @@ impl<A, S: 'static> WebsocketContext<A, S> where A: Actor<Context=Self> {
}
}
impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
impl<A, S> WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
/// Write payload
#[inline]
fn write(&mut self, data: Binary) {
@ -145,13 +156,23 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Send ping frame
#[inline]
pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false));
self.write(Frame::message(
Vec::from(message),
OpCode::Ping,
true,
false,
));
}
/// Send pong frame
#[inline]
pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false));
self.write(Frame::message(
Vec::from(message),
OpCode::Pong,
true,
false,
));
}
/// Send close frame
@ -191,8 +212,11 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
}
}
impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Self>, S: 'static {
impl<A, S> ActorHttpContext for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
S: 'static,
{
#[inline]
fn disconnected(&mut self) {
self.disconnected = true;
@ -200,12 +224,11 @@ impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Se
}
fn poll(&mut self) -> Poll<Option<SmallVec<[ContextFrame; 4]>>, Error> {
let ctx: &mut WebsocketContext<A, S> = unsafe {
mem::transmute(self as &mut WebsocketContext<A, S>)
};
let ctx: &mut WebsocketContext<A, S> =
unsafe { mem::transmute(self as &mut WebsocketContext<A, S>) };
if self.inner.alive() && self.inner.poll(ctx).is_err() {
return Err(ErrorInternalServerError("error"))
return Err(ErrorInternalServerError("error"));
}
// frames
@ -220,8 +243,10 @@ impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Se
}
impl<A, M, S> ToEnvelope<Syn, A, M> for WebsocketContext<A, S>
where A: Actor<Context=WebsocketContext<A, S>> + Handler<M>,
M: Message + Send + 'static, M::Result: Send
where
A: Actor<Context = WebsocketContext<A, S>> + Handler<M>,
M: Message + Send + 'static,
M::Result: Send,
{
fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> {
SyncEnvelope::new(msg, tx)
@ -229,7 +254,9 @@ impl<A, M, S> ToEnvelope<Syn, A, M> for WebsocketContext<A, S>
}
impl<A, S> From<WebsocketContext<A, S>> for Body
where A: Actor<Context=WebsocketContext<A, S>>, S: 'static
where
A: Actor<Context = WebsocketContext<A, S>>,
S: 'static,
{
fn from(ctx: WebsocketContext<A, S>) -> Body {
Body::Actor(Box::new(ctx))

View File

@ -1,17 +1,17 @@
use std::{fmt, mem, ptr};
use std::iter::FromIterator;
use bytes::{Bytes, BytesMut, BufMut};
use byteorder::{ByteOrder, BigEndian, NetworkEndian};
use byteorder::{BigEndian, ByteOrder, NetworkEndian};
use bytes::{BufMut, Bytes, BytesMut};
use futures::{Async, Poll, Stream};
use rand;
use std::iter::FromIterator;
use std::{fmt, mem, ptr};
use body::Binary;
use error::{PayloadError};
use error::PayloadError;
use payload::PayloadHelper;
use ws::ProtocolError;
use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask;
use ws::proto::{CloseCode, OpCode};
/// A struct representing a `WebSocket` frame.
#[derive(Debug)]
@ -22,7 +22,6 @@ pub struct Frame {
}
impl Frame {
/// Destruct frame
pub fn unpack(self) -> (bool, OpCode, Binary) {
(self.finished, self.opcode, self.payload)
@ -40,20 +39,22 @@ impl Frame {
Vec::new()
} else {
Vec::from_iter(
raw[..].iter()
raw[..]
.iter()
.chain(reason.as_bytes().iter())
.cloned())
.cloned(),
)
};
Frame::message(payload, OpCode::Close, true, genmask)
}
#[cfg_attr(feature="cargo-clippy", allow(type_complexity))]
fn read_copy_md<S>(pl: &mut PayloadHelper<S>,
server: bool,
max_size: usize
#[cfg_attr(feature = "cargo-clippy", allow(type_complexity))]
fn read_copy_md<S>(
pl: &mut PayloadHelper<S>, server: bool, max_size: usize
) -> Poll<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError>
where S: Stream<Item=Bytes, Error=PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
let mut idx = 2;
let buf = match pl.copy(2)? {
@ -68,16 +69,16 @@ impl Frame {
// check masking
let masked = second & 0x80 != 0;
if !masked && server {
return Err(ProtocolError::UnmaskedFrame)
return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server {
return Err(ProtocolError::MaskedFrame)
return Err(ProtocolError::MaskedFrame);
}
// Op code
let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F))
return Err(ProtocolError::InvalidOpcode(first & 0x0F));
}
let len = second & 0x7F;
@ -105,7 +106,7 @@ impl Frame {
// check for max allowed size
if length > max_size {
return Err(ProtocolError::Overflow)
return Err(ProtocolError::Overflow);
}
let mask = if server {
@ -115,25 +116,32 @@ impl Frame {
Async::NotReady => return Ok(Async::NotReady),
};
let mask: &[u8] = &buf[idx..idx+4];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
let mask: &[u8] = &buf[idx..idx + 4];
let mask_u32: u32 =
unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
idx += 4;
Some(mask_u32)
} else {
None
};
Ok(Async::Ready(Some((idx, finished, opcode, length, mask))))
Ok(Async::Ready(Some((
idx,
finished,
opcode,
length,
mask,
))))
}
fn read_chunk_md(chunk: &[u8], server: bool, max_size: usize)
-> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError>
{
fn read_chunk_md(
chunk: &[u8], server: bool, max_size: usize
) -> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError> {
let chunk_len = chunk.len();
let mut idx = 2;
if chunk_len < 2 {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
let first = chunk[0];
@ -143,29 +151,29 @@ impl Frame {
// check masking
let masked = second & 0x80 != 0;
if !masked && server {
return Err(ProtocolError::UnmaskedFrame)
return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server {
return Err(ProtocolError::MaskedFrame)
return Err(ProtocolError::MaskedFrame);
}
// Op code
let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F))
return Err(ProtocolError::InvalidOpcode(first & 0x0F));
}
let len = second & 0x7F;
let length = if len == 126 {
if chunk_len < 4 {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize;
idx += 2;
len
} else if len == 127 {
if chunk_len < 10 {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
let len = NetworkEndian::read_uint(&chunk[idx..], 8) as usize;
idx += 8;
@ -176,16 +184,17 @@ impl Frame {
// check for max allowed size
if length > max_size {
return Err(ProtocolError::Overflow)
return Err(ProtocolError::Overflow);
}
let mask = if server {
if chunk_len < idx + 4 {
return Ok(Async::NotReady)
return Ok(Async::NotReady);
}
let mask: &[u8] = &chunk[idx..idx+4];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
let mask: &[u8] = &chunk[idx..idx + 4];
let mask_u32: u32 =
unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
idx += 4;
Some(mask_u32)
} else {
@ -196,9 +205,11 @@ impl Frame {
}
/// Parse the input stream into a frame.
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize)
-> Poll<Option<Frame>, ProtocolError>
where S: Stream<Item=Bytes, Error=PayloadError>
pub fn parse<S>(
pl: &mut PayloadHelper<S>, server: bool, max_size: usize
) -> Poll<Option<Frame>, ProtocolError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
// try to parse ws frame md from one chunk
let result = match pl.get_chunk()? {
@ -229,7 +240,10 @@ impl Frame {
// no need for body
if length == 0 {
return Ok(Async::Ready(Some(Frame {
finished, opcode, payload: Binary::from("") })));
finished,
opcode,
payload: Binary::from(""),
})));
}
let data = match pl.read_exact(length)? {
@ -245,26 +259,32 @@ impl Frame {
}
OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Async::Ready(Some(Frame::default())))
return Ok(Async::Ready(Some(Frame::default())));
}
_ => ()
_ => (),
}
// unmask
if let Some(mask) = mask {
#[allow(mutable_transmutes)]
let p: &mut [u8] = unsafe{let ptr: &[u8] = &data; mem::transmute(ptr)};
let p: &mut [u8] = unsafe {
let ptr: &[u8] = &data;
mem::transmute(ptr)
};
apply_mask(p, mask);
}
Ok(Async::Ready(Some(Frame {
finished, opcode, payload: data.into() })))
finished,
opcode,
payload: data.into(),
})))
}
/// Generate binary representation
pub fn message<B: Into<Binary>>(data: B, code: OpCode,
finished: bool, genmask: bool) -> Binary
{
pub fn message<B: Into<Binary>>(
data: B, code: OpCode, finished: bool, genmask: bool
) -> Binary {
let payload = data.into();
let one: u8 = if finished {
0x80 | Into::<u8>::into(code)
@ -286,19 +306,19 @@ impl Frame {
let mut buf = BytesMut::with_capacity(p_len + 4);
buf.put_slice(&[one, two | 126]);
{
let buf_mut = unsafe{buf.bytes_mut()};
let buf_mut = unsafe { buf.bytes_mut() };
BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16);
}
unsafe{buf.advance_mut(2)};
unsafe { buf.advance_mut(2) };
buf
} else {
let mut buf = BytesMut::with_capacity(p_len + 10);
buf.put_slice(&[one, two | 127]);
{
let buf_mut = unsafe{buf.bytes_mut()};
let buf_mut = unsafe { buf.bytes_mut() };
BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64);
}
unsafe{buf.advance_mut(8)};
unsafe { buf.advance_mut(8) };
buf
};
@ -308,7 +328,7 @@ impl Frame {
{
let buf_mut = buf.bytes_mut();
*(buf_mut as *mut _ as *mut u32) = mask;
buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref());
buf_mut[4..payload_len + 4].copy_from_slice(payload.as_ref());
apply_mask(&mut buf_mut[4..], mask);
}
buf.advance_mut(payload_len + 4);
@ -333,7 +353,8 @@ impl Default for Frame {
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f,
write!(
f,
"
<FRAME>
final: {}
@ -341,11 +362,15 @@ impl fmt::Display for Frame {
payload length: {}
payload: 0x{}
</FRAME>",
self.finished,
self.opcode,
self.payload.len(),
self.payload.as_ref().iter().map(
|byte| format!("{:x}", byte)).collect::<String>())
self.finished,
self.opcode,
self.payload.len(),
self.payload
.as_ref()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
)
}
}
@ -360,7 +385,7 @@ mod tests {
_ => false,
}
}
fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame {
match frm {
Ok(Async::Ready(Some(frame))) => frame,
@ -370,8 +395,9 @@ mod tests {
#[test]
fn test_parse() {
let mut buf = PayloadHelper::new(
once(Ok(BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]).freeze())));
let mut buf = PayloadHelper::new(once(Ok(BytesMut::from(
&[0b0000_0001u8, 0b0000_0001u8][..],
).freeze())));
assert!(is_none(&Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);

View File

@ -20,7 +20,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
/// Faster version of `apply_mask()` which operates on 8-byte blocks.
#[inline]
#[cfg_attr(feature="cargo-clippy", allow(cast_lossless))]
#[cfg_attr(feature = "cargo-clippy", allow(cast_lossless))]
fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) {
let mut ptr = buf.as_mut_ptr();
let mut len = buf.len();
@ -85,13 +85,16 @@ fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) {
// Possible last block.
if len > 0 {
unsafe { xor_mem(ptr, mask_u32, len); }
unsafe {
xor_mem(ptr, mask_u32, len);
}
}
}
#[inline]
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient,
// it could be done better. The compiler does not see that len is limited to 3.
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so
// inefficient, it could be done better. The compiler does not see that len is
// limited to 3.
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
let mut b: u32 = uninitialized();
#[allow(trivial_casts)]
@ -103,19 +106,17 @@ unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
#[cfg(test)]
mod tests {
use super::{apply_mask_fallback, apply_mask_fast32};
use std::ptr;
use super::{apply_mask_fallback, apply_mask_fast32};
#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
let mask = [0x6d, 0xb6, 0xb2, 0x80];
let mask_u32: u32 = unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82,
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03,
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17,
0x74, 0xf9, 0x12, 0x03,
];
// Check masking with proper alignment.

View File

@ -1,7 +1,8 @@
//! `WebSocket` support for Actix
//!
//! To setup a `WebSocket`, first do web socket handshake then on success convert `Payload`
//! into a `WsStream` stream and then use `WsWriter` to communicate with the peer.
//! To setup a `WebSocket`, first do web socket handshake then on success
//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
//! communicate with the peer.
//!
//! ## Example
//!
@ -42,62 +43,61 @@
//! # .finish();
//! # }
//! ```
use bytes::Bytes;
use http::{Method, StatusCode, header};
use futures::{Async, Poll, Stream};
use byteorder::{ByteOrder, NetworkEndian};
use bytes::Bytes;
use futures::{Async, Poll, Stream};
use http::{header, Method, StatusCode};
use actix::{Actor, AsyncContext, StreamHandler};
use body::Binary;
use payload::PayloadHelper;
use error::{Error, PayloadError, ResponseError};
use httpmessage::HttpMessage;
use httprequest::HttpRequest;
use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder};
use payload::PayloadHelper;
mod frame;
mod proto;
mod context;
mod mask;
mod client;
mod context;
mod frame;
mod mask;
mod proto;
pub use self::frame::Frame;
pub use self::proto::OpCode;
pub use self::proto::CloseCode;
pub use self::client::{Client, ClientError, ClientHandshake, ClientReader, ClientWriter};
pub use self::context::WebsocketContext;
pub use self::client::{Client, ClientError,
ClientReader, ClientWriter, ClientHandshake};
pub use self::frame::Frame;
pub use self::proto::CloseCode;
pub use self::proto::OpCode;
/// Websocket protocol errors
#[derive(Fail, Debug)]
pub enum ProtocolError {
/// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")]
#[fail(display = "Received an unmasked frame from client")]
UnmaskedFrame,
/// Received a masked frame from server
#[fail(display="Received a masked frame from server")]
#[fail(display = "Received a masked frame from server")]
MaskedFrame,
/// Encountered invalid opcode
#[fail(display="Invalid opcode: {}", _0)]
#[fail(display = "Invalid opcode: {}", _0)]
InvalidOpcode(u8),
/// Invalid control frame length
#[fail(display="Invalid control frame length: {}", _0)]
#[fail(display = "Invalid control frame length: {}", _0)]
InvalidLength(usize),
/// Bad web socket op code
#[fail(display="Bad web socket op code")]
#[fail(display = "Bad web socket op code")]
BadOpCode,
/// A payload reached size limit.
#[fail(display="A payload reached size limit.")]
#[fail(display = "A payload reached size limit.")]
Overflow,
/// Continuation is not supported
#[fail(display="Continuation is not supported.")]
#[fail(display = "Continuation is not supported.")]
NoContinuation,
/// Bad utf-8 encoding
#[fail(display="Bad utf-8 encoding.")]
#[fail(display = "Bad utf-8 encoding.")]
BadEncoding,
/// Payload error
#[fail(display="Payload error: {}", _0)]
#[fail(display = "Payload error: {}", _0)]
Payload(#[cause] PayloadError),
}
@ -113,42 +113,46 @@ impl From<PayloadError> for ProtocolError {
#[derive(Fail, PartialEq, Debug)]
pub enum HandshakeError {
/// Only get method is allowed
#[fail(display="Method not allowed")]
#[fail(display = "Method not allowed")]
GetMethodRequired,
/// Upgrade header if not set to websocket
#[fail(display="Websocket upgrade is expected")]
#[fail(display = "Websocket upgrade is expected")]
NoWebsocketUpgrade,
/// Connection header is not set to upgrade
#[fail(display="Connection upgrade is expected")]
#[fail(display = "Connection upgrade is expected")]
NoConnectionUpgrade,
/// Websocket version header is not set
#[fail(display="Websocket version header is required")]
#[fail(display = "Websocket version header is required")]
NoVersionHeader,
/// Unsupported websocket version
#[fail(display="Unsupported version")]
#[fail(display = "Unsupported version")]
UnsupportedVersion,
/// Websocket key is not set or wrong
#[fail(display="Unknown websocket key")]
#[fail(display = "Unknown websocket key")]
BadWebsocketKey,
}
impl ResponseError for HandshakeError {
fn error_response(&self) -> HttpResponse {
match *self {
HandshakeError::GetMethodRequired => {
HttpResponse::MethodNotAllowed().header(header::ALLOW, "GET").finish()
}
HandshakeError::GetMethodRequired => HttpResponse::MethodNotAllowed()
.header(header::ALLOW, "GET")
.finish(),
HandshakeError::NoWebsocketUpgrade => HttpResponse::BadRequest()
.reason("No WebSocket UPGRADE header found").finish(),
.reason("No WebSocket UPGRADE header found")
.finish(),
HandshakeError::NoConnectionUpgrade => HttpResponse::BadRequest()
.reason("No CONNECTION upgrade").finish(),
.reason("No CONNECTION upgrade")
.finish(),
HandshakeError::NoVersionHeader => HttpResponse::BadRequest()
.reason("Websocket version header is required").finish(),
.reason("Websocket version header is required")
.finish(),
HandshakeError::UnsupportedVersion => HttpResponse::BadRequest()
.reason("Unsupported version").finish(),
.reason("Unsupported version")
.finish(),
HandshakeError::BadWebsocketKey => HttpResponse::BadRequest()
.reason("Handshake error").finish(),
.reason("Handshake error")
.finish(),
}
}
}
@ -165,8 +169,9 @@ pub enum Message {
/// Do websocket handshake and start actor
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
S: 'static
where
A: Actor<Context = WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
S: 'static,
{
let mut resp = handshake(&req)?;
let stream = WsStream::new(req.clone());
@ -185,10 +190,12 @@ pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, HandshakeError> {
pub fn handshake<S>(
req: &HttpRequest<S>
) -> Result<HttpResponseBuilder, HandshakeError> {
// WebSocket accepts only GET
if *req.method() != Method::GET {
return Err(HandshakeError::GetMethodRequired)
return Err(HandshakeError::GetMethodRequired);
}
// Check for "UPGRADE" to websocket header
@ -202,17 +209,19 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
false
};
if !has_hdr {
return Err(HandshakeError::NoWebsocketUpgrade)
return Err(HandshakeError::NoWebsocketUpgrade);
}
// Upgrade connection
if !req.upgrade() {
return Err(HandshakeError::NoConnectionUpgrade)
return Err(HandshakeError::NoConnectionUpgrade);
}
// check supported version
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
return Err(HandshakeError::NoVersionHeader)
if !req.headers()
.contains_key(header::SEC_WEBSOCKET_VERSION)
{
return Err(HandshakeError::NoVersionHeader);
}
let supported_ver = {
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
@ -222,12 +231,12 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
}
};
if !supported_ver {
return Err(HandshakeError::UnsupportedVersion)
return Err(HandshakeError::UnsupportedVersion);
}
// check client handshake for validity
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(HandshakeError::BadWebsocketKey)
return Err(HandshakeError::BadWebsocketKey);
}
let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
@ -235,11 +244,11 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
};
Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade)
.header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take())
.connection_type(ConnectionType::Upgrade)
.header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take())
}
/// Maps `Payload` stream into stream of `ws::Message` items
@ -249,12 +258,16 @@ pub struct WsStream<S> {
max_size: usize,
}
impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create new websocket frames stream
pub fn new(stream: S) -> WsStream<S> {
WsStream { rx: PayloadHelper::new(stream),
closed: false,
max_size: 65_536,
WsStream {
rx: PayloadHelper::new(stream),
closed: false,
max_size: 65_536,
}
}
@ -267,13 +280,16 @@ impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
}
}
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
impl<S> Stream for WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = Message;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed {
return Ok(Async::Ready(None))
return Ok(Async::Ready(None));
}
match Frame::parse(&mut self.rx, true, self.max_size) {
@ -283,7 +299,7 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// continuation is not supported
if !finished {
self.closed = true;
return Err(ProtocolError::NoContinuation)
return Err(ProtocolError::NoContinuation);
}
match opcode {
@ -295,23 +311,21 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
OpCode::Close => {
self.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
Ok(Async::Ready(
Some(Message::Close(CloseCode::from(code)))))
},
OpCode::Ping =>
Ok(Async::Ready(Some(
Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Pong =>
Ok(Async::Ready(Some(
Message::Pong(String::from_utf8_lossy(payload.as_ref()).into())))),
OpCode::Binary =>
Ok(Async::Ready(Some(Message::Binary(payload)))),
Ok(Async::Ready(Some(Message::Close(CloseCode::from(
code,
)))))
}
OpCode::Ping => Ok(Async::Ready(Some(Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Pong => Ok(Async::Ready(Some(Message::Pong(
String::from_utf8_lossy(payload.as_ref()).into(),
)))),
OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => {
let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) {
Ok(s) =>
Ok(Async::Ready(Some(Message::Text(s)))),
Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => {
self.closed = true;
Err(ProtocolError::BadEncoding)
@ -333,77 +347,168 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[cfg(test)]
mod tests {
use super::*;
use http::{header, HeaderMap, Method, Uri, Version};
use std::str::FromStr;
use http::{Method, HeaderMap, Version, Uri, header};
#[test]
fn test_handshake() {
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap());
let req = HttpRequest::new(
Method::POST,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert_eq!(
HandshakeError::GetMethodRequired,
handshake(&req).err().unwrap()
);
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, HeaderMap::new(), None);
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("test"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("test"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoConnectionUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
headers.insert(header::CONNECTION,
header::HeaderValue::from_static("upgrade"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
);
headers.insert(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoVersionHeader,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
headers.insert(header::CONNECTION,
header::HeaderValue::from_static("upgrade"));
headers.insert(header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("5"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
);
headers.insert(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
);
headers.insert(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("5"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::UnsupportedVersion,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
headers.insert(header::CONNECTION,
header::HeaderValue::from_static("upgrade"));
headers.insert(header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
);
headers.insert(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
);
headers.insert(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::BadWebsocketKey,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE,
header::HeaderValue::from_static("websocket"));
headers.insert(header::CONNECTION,
header::HeaderValue::from_static("upgrade"));
headers.insert(header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"));
headers.insert(header::SEC_WEBSOCKET_KEY,
header::HeaderValue::from_static("13"));
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(),
Version::HTTP_11, headers, None);
assert_eq!(StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status());
headers.insert(
header::UPGRADE,
header::HeaderValue::from_static("websocket"),
);
headers.insert(
header::CONNECTION,
header::HeaderValue::from_static("upgrade"),
);
headers.insert(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"),
);
headers.insert(
header::SEC_WEBSOCKET_KEY,
header::HeaderValue::from_static("13"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status()
);
}
#[test]

View File

@ -1,7 +1,7 @@
use std::fmt;
use std::convert::{Into, From};
use sha1;
use base64;
use sha1;
use std::convert::{From, Into};
use std::fmt;
use self::OpCode::*;
/// Operation codes as part of rfc6455.
@ -26,52 +26,54 @@ pub enum OpCode {
impl fmt::Display for OpCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Continue => write!(f, "CONTINUE"),
Text => write!(f, "TEXT"),
Binary => write!(f, "BINARY"),
Close => write!(f, "CLOSE"),
Ping => write!(f, "PING"),
Pong => write!(f, "PONG"),
Bad => write!(f, "BAD"),
Continue => write!(f, "CONTINUE"),
Text => write!(f, "TEXT"),
Binary => write!(f, "BINARY"),
Close => write!(f, "CLOSE"),
Ping => write!(f, "PING"),
Pong => write!(f, "PONG"),
Bad => write!(f, "BAD"),
}
}
}
impl Into<u8> for OpCode {
fn into(self) -> u8 {
match self {
Continue => 0,
Text => 1,
Binary => 2,
Close => 8,
Ping => 9,
Pong => 10,
Bad => {
debug_assert!(false, "Attempted to convert invalid opcode to u8. This is a bug.");
8 // if this somehow happens, a close frame will help us tear down quickly
Continue => 0,
Text => 1,
Binary => 2,
Close => 8,
Ping => 9,
Pong => 10,
Bad => {
debug_assert!(
false,
"Attempted to convert invalid opcode to u8. This is a bug."
);
8 // if this somehow happens, a close frame will help us tear down quickly
}
}
}
}
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
match byte {
0 => Continue,
1 => Text,
2 => Binary,
8 => Close,
9 => Ping,
10 => Pong,
_ => Bad
0 => Continue,
1 => Text,
2 => Binary,
8 => Close,
9 => Ping,
10 => Pong,
_ => Bad,
}
}
}
use self::CloseCode::*;
/// Status code used to indicate why an endpoint is closing the `WebSocket` connection.
/// Status code used to indicate why an endpoint is closing the `WebSocket`
/// connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum CloseCode {
/// Indicates a normal closure, meaning that the purpose for
@ -125,12 +127,13 @@ pub enum CloseCode {
/// it encountered an unexpected condition that prevented it from
/// fulfilling the request.
Error,
/// Indicates that the server is restarting. A client may choose to reconnect,
/// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
/// Indicates that the server is restarting. A client may choose to
/// reconnect, and if it does, it should use a randomized delay of 5-30
/// seconds between attempts.
Restart,
/// Indicates that the server is overloaded and the client should either connect
/// to a different IP (when multiple targets exist), or reconnect to the same IP
/// when a user has performed an action.
/// Indicates that the server is overloaded and the client should either
/// connect to a different IP (when multiple targets exist), or
/// reconnect to the same IP when a user has performed an action.
Again,
#[doc(hidden)]
Tls,
@ -141,31 +144,29 @@ pub enum CloseCode {
}
impl Into<u16> for CloseCode {
fn into(self) -> u16 {
match self {
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Empty => 0,
Other(code) => code,
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Empty => 0,
Other(code) => code,
}
}
}
impl From<u16> for CloseCode {
fn from(code: u16) -> CloseCode {
match code {
1000 => Normal,
@ -182,7 +183,7 @@ impl From<u16> for CloseCode {
1012 => Restart,
1013 => Again,
1015 => Tls,
0 => Empty,
0 => Empty,
_ => Other(code),
}
}
@ -200,7 +201,6 @@ pub(crate) fn hash_key(key: &[u8]) -> String {
base64::encode(&hasher.digest().bytes())
}
#[cfg(test)]
mod test {
#![allow(unused_imports, unused_variables, dead_code)]
@ -210,9 +210,9 @@ mod test {
($from:expr => $opcode:pat) => {
match OpCode::from($from) {
e @ $opcode => (),
e => unreachable!("{:?}", e)
e => unreachable!("{:?}", e),
}
}
};
}
macro_rules! opcode_from {
@ -220,9 +220,9 @@ mod test {
let res: u8 = $from.into();
match res {
e @ $opcode => (),
e => unreachable!("{:?}", e)
e => unreachable!("{:?}", e),
}
}
};
}
#[test]

View File

@ -1,49 +1,46 @@
extern crate actix;
extern crate actix_web;
extern crate bytes;
extern crate futures;
extern crate flate2;
extern crate futures;
extern crate rand;
use std::io::Read;
use bytes::Bytes;
use flate2::read::GzDecoder;
use futures::Future;
use futures::stream::once;
use flate2::read::GzDecoder;
use rand::Rng;
use actix_web::*;
const STR: &str =
"Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_simple() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let mut srv =
test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().header("x-test", "111").finish().unwrap();
let repr = format!("{:?}", request);
@ -68,23 +65,26 @@ fn test_simple() {
#[test]
fn test_with_query_parameter() {
let mut srv = test::TestServer::new(
|app| app.handler(|req: HttpRequest| match req.query().get("qp") {
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| match req.query().get("qp") {
Some(_) => HttpResponse::Ok().finish(),
None => HttpResponse::BadRequest().finish(),
}));
})
});
let request = srv.get().uri(srv.url("/?qp=5").as_str()).finish().unwrap();
let request = srv.get()
.uri(srv.url("/?qp=5").as_str())
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
}
#[test]
fn test_no_decompress() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let mut srv =
test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -111,19 +111,23 @@ fn test_no_decompress() {
#[test]
fn test_client_gzip_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip)
.body(STR).unwrap();
.body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -136,19 +140,23 @@ fn test_client_gzip_encoding() {
fn test_client_gzip_encoding_large() {
let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.clone()).unwrap();
.body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -164,19 +172,23 @@ fn test_client_gzip_encoding_large_random() {
.take(100_000)
.collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.clone()).unwrap();
.body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -185,22 +197,26 @@ fn test_client_gzip_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_client_brotli_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.client(http::Method::POST, "/")
.content_encoding(http::ContentEncoding::Br)
.body(STR).unwrap();
.body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -209,7 +225,7 @@ fn test_client_brotli_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_client_brotli_encoding_large_random() {
let data = rand::thread_rng()
@ -217,19 +233,23 @@ fn test_client_brotli_encoding_large_random() {
.take(70_000)
.collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(move |bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(move |bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.client(http::Method::POST, "/")
.content_encoding(http::ContentEncoding::Br)
.body(data.clone()).unwrap();
.body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -239,22 +259,26 @@ fn test_client_brotli_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_client_deflate_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.post()
.content_encoding(http::ContentEncoding::Deflate)
.body(STR).unwrap();
.body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -263,7 +287,7 @@ fn test_client_deflate_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_client_deflate_encoding_large_random() {
let data = rand::thread_rng()
@ -271,19 +295,23 @@ fn test_client_deflate_encoding_large_random() {
.take(70_000)
.collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(bytes))
})
.responder()
})
});
// client request
let request = srv.post()
.content_encoding(http::ContentEncoding::Deflate)
.body(data.clone()).unwrap();
.body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -294,20 +322,25 @@ fn test_client_deflate_encoding_large_random() {
#[test]
fn test_client_streaming_explicit() {
let mut srv = test::TestServer::new(
|app| app.handler(
|req: HttpRequest| req.body()
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.map_err(Error::from)
.and_then(|body| {
Ok(HttpResponse::Ok()
.chunked()
.content_encoding(http::ContentEncoding::Identity)
.body(body))})
.responder()));
.chunked()
.content_encoding(http::ContentEncoding::Identity)
.body(body))
})
.responder()
})
});
let body = once(Ok(Bytes::from_static(STR.as_ref())));
let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap();
let request = srv.get()
.body(Body::Streaming(Box::new(body)))
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -318,12 +351,14 @@ fn test_client_streaming_explicit() {
#[test]
fn test_body_streaming_implicit() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))}));
.body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -338,7 +373,7 @@ fn test_body_streaming_implicit() {
fn test_client_cookie_handling() {
use actix_web::http::Cookie;
fn err() -> Error {
use std::io::{ErrorKind, Error as IoError};
use std::io::{Error as IoError, ErrorKind};
// stub some generic error
Error::from(IoError::from(ErrorKind::NotFound))
}
@ -352,13 +387,12 @@ fn test_client_cookie_handling() {
// Q: are all these clones really necessary? A: Yes, possibly
let cookie1b = cookie1.clone();
let cookie2b = cookie2.clone();
let mut srv = test::TestServer::new(
move |app| {
let cookie1 = cookie1b.clone();
let cookie2 = cookie2b.clone();
app.handler(move |req: HttpRequest| {
// Check cookies were sent correctly
req.cookie("cookie1").ok_or_else(err)
let mut srv = test::TestServer::new(move |app| {
let cookie1 = cookie1b.clone();
let cookie2 = cookie2b.clone();
app.handler(move |req: HttpRequest| {
// Check cookies were sent correctly
req.cookie("cookie1").ok_or_else(err)
.and_then(|c1| if c1.value() == "value1" {
Ok(())
} else {
@ -376,8 +410,8 @@ fn test_client_cookie_handling() {
.cookie(cookie2.clone())
.finish()
)
})
});
})
});
let request = srv.get()
.cookie(cookie1.clone())

View File

@ -1,11 +1,12 @@
extern crate actix;
extern crate actix_web;
extern crate tokio_core;
extern crate bytes;
extern crate futures;
extern crate h2;
extern crate http;
extern crate bytes;
#[macro_use] extern crate serde_derive;
extern crate tokio_core;
#[macro_use]
extern crate serde_derive;
use actix_web::*;
use bytes::Bytes;
@ -19,15 +20,16 @@ struct PParam {
#[test]
fn test_path_extractor() {
let mut srv = test::TestServer::new(|app| {
app.resource(
"/{username}/index.html", |r| r.with(
|p: Path<PParam>| format!("Welcome {}!", p.username)));
}
);
app.resource("/{username}/index.html", |r| {
r.with(|p: Path<PParam>| format!("Welcome {}!", p.username))
});
});
// client request
let request = srv.get().uri(srv.url("/test/index.html"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/test/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -39,15 +41,16 @@ fn test_path_extractor() {
#[test]
fn test_query_extractor() {
let mut srv = test::TestServer::new(|app| {
app.resource(
"/index.html", |r| r.with(
|p: Query<PParam>| format!("Welcome {}!", p.username)));
}
);
app.resource("/index.html", |r| {
r.with(|p: Query<PParam>| format!("Welcome {}!", p.username))
});
});
// client request
let request = srv.get().uri(srv.url("/index.html?username=test"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/index.html?username=test"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -56,8 +59,10 @@ fn test_query_extractor() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test!"));
// client request
let request = srv.get().uri(srv.url("/index.html"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
@ -65,16 +70,18 @@ fn test_query_extractor() {
#[test]
fn test_path_and_query_extractor() {
let mut srv = test::TestServer::new(|app| {
app.resource(
"/{username}/index.html", |r| r.route().with2(
|p: Path<PParam>, q: Query<PParam>|
format!("Welcome {} - {}!", p.username, q.username)));
}
);
app.resource("/{username}/index.html", |r| {
r.route().with2(|p: Path<PParam>, q: Query<PParam>| {
format!("Welcome {} - {}!", p.username, q.username)
})
});
});
// client request
let request = srv.get().uri(srv.url("/test1/index.html?username=test2"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/test1/index.html?username=test2"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -83,8 +90,10 @@ fn test_path_and_query_extractor() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!"));
// client request
let request = srv.get().uri(srv.url("/test1/index.html"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/test1/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
@ -92,16 +101,19 @@ fn test_path_and_query_extractor() {
#[test]
fn test_path_and_query_extractor2() {
let mut srv = test::TestServer::new(|app| {
app.resource(
"/{username}/index.html", |r| r.route().with3(
|_: HttpRequest, p: Path<PParam>, q: Query<PParam>|
format!("Welcome {} - {}!", p.username, q.username)));
}
);
app.resource("/{username}/index.html", |r| {
r.route()
.with3(|_: HttpRequest, p: Path<PParam>, q: Query<PParam>| {
format!("Welcome {} - {}!", p.username, q.username)
})
});
});
// client request
let request = srv.get().uri(srv.url("/test1/index.html?username=test2"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/test1/index.html?username=test2"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -110,8 +122,10 @@ fn test_path_and_query_extractor2() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!"));
// client request
let request = srv.get().uri(srv.url("/test1/index.html"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/test1/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
@ -123,8 +137,10 @@ fn test_non_ascii_route() {
});
// client request
let request = srv.get().uri(srv.url("/中文/index.html"))
.finish().unwrap();
let request = srv.get()
.uri(srv.url("/中文/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());

View File

@ -1,60 +1,58 @@
extern crate actix;
extern crate actix_web;
extern crate tokio_core;
extern crate bytes;
extern crate flate2;
extern crate futures;
extern crate h2;
extern crate http as modhttp;
extern crate bytes;
extern crate flate2;
extern crate rand;
extern crate tokio_core;
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
extern crate brotli2;
use std::{net, thread, time};
use std::io::{Read, Write};
use std::sync::{Arc, mpsc};
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "brotli")]
use brotli2::write::{BrotliDecoder, BrotliEncoder};
use bytes::{Bytes, BytesMut};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::{GzEncoder, DeflateEncoder, DeflateDecoder};
#[cfg(feature="brotli")]
use brotli2::write::{BrotliEncoder, BrotliDecoder};
use futures::{Future, Stream};
use flate2::write::{DeflateDecoder, DeflateEncoder, GzEncoder};
use futures::stream::once;
use futures::{Future, Stream};
use h2::client as h2client;
use bytes::{Bytes, BytesMut};
use modhttp::Request;
use rand::Rng;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{mpsc, Arc};
use std::{net, thread, time};
use tokio_core::net::TcpStream;
use tokio_core::reactor::Core;
use rand::Rng;
use actix::System;
use actix_web::*;
const STR: &str =
"Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_start() {
@ -63,11 +61,13 @@ fn test_start() {
thread::spawn(move || {
let sys = System::new("test");
let srv = server::new(
|| vec![App::new()
.resource(
"/", |r| r.method(http::Method::GET)
.f(|_|HttpResponse::Ok()))]);
let srv = server::new(|| {
vec![
App::new().resource("/", |r| {
r.method(http::Method::GET).f(|_| HttpResponse::Ok())
}),
]
});
let srv = srv.bind("127.0.0.1:0").unwrap();
let addr = srv.addrs()[0];
@ -79,7 +79,9 @@ fn test_start() {
let mut sys = System::new("test-server");
{
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap();
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap();
assert!(response.status().is_success());
}
@ -94,7 +96,9 @@ fn test_start() {
thread::sleep(time::Duration::from_millis(200));
{
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap();
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap();
assert!(response.status().is_success());
}
@ -108,10 +112,13 @@ fn test_shutdown() {
thread::spawn(move || {
let sys = System::new("test");
let srv = server::new(
|| vec![App::new()
.resource(
"/", |r| r.method(http::Method::GET).f(|_| HttpResponse::Ok()))]);
let srv = server::new(|| {
vec![
App::new().resource("/", |r| {
r.method(http::Method::GET).f(|_| HttpResponse::Ok())
}),
]
});
let srv = srv.bind("127.0.0.1:0").unwrap();
let addr = srv.addrs()[0];
@ -124,9 +131,11 @@ fn test_shutdown() {
let mut sys = System::new("test-server");
{
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap();
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap();
srv_addr.do_send(server::StopServer{graceful: true});
srv_addr.do_send(server::StopServer { graceful: true });
assert!(response.status().is_success());
}
@ -146,30 +155,31 @@ fn test_simple() {
fn test_headers() {
let data = STR.repeat(10);
let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new(
move |app| {
let data = srv_data.clone();
app.handler(move |_| {
let mut builder = HttpResponse::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ");
}
builder.body(data.as_ref())})
});
let mut srv = test::TestServer::new(move |app| {
let data = srv_data.clone();
app.handler(move |_| {
let mut builder = HttpResponse::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
builder.body(data.as_ref())
})
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -182,8 +192,8 @@ fn test_headers() {
#[test]
fn test_body() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let mut srv =
test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -196,11 +206,13 @@ fn test_body() {
#[test]
fn test_body_gzip() {
let mut srv = test::TestServer::new(
|app| app.handler(
|_| HttpResponse::Ok()
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(STR)));
.body(STR)
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -221,13 +233,14 @@ fn test_body_gzip_large() {
let data = STR.repeat(10);
let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new(
move |app| {
let data = srv_data.clone();
app.handler(
move |_| HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.as_ref()))});
let mut srv = test::TestServer::new(move |app| {
let data = srv_data.clone();
app.handler(move |_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.as_ref())
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -251,13 +264,14 @@ fn test_body_gzip_large_random() {
.collect::<String>();
let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new(
move |app| {
let data = srv_data.clone();
app.handler(
move |_| HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.as_ref()))});
let mut srv = test::TestServer::new(move |app| {
let data = srv_data.clone();
app.handler(move |_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(data.as_ref())
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -276,12 +290,14 @@ fn test_body_gzip_large_random() {
#[test]
fn test_body_chunked_implicit() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))}));
.body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -297,15 +313,17 @@ fn test_body_chunked_implicit() {
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_body_br_streaming() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(Body::Streaming(Box::new(body)))}));
.body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -323,17 +341,23 @@ fn test_body_br_streaming() {
#[test]
fn test_head_empty() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_length(STR.len() as u64).finish()}));
.content_length(STR.len() as u64)
.finish()
})
});
let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap();
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
@ -344,18 +368,24 @@ fn test_head_empty() {
#[test]
fn test_head_binary() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.content_length(100).body(STR)}));
.content_length(100)
.body(STR)
})
});
let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap();
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
@ -366,32 +396,38 @@ fn test_head_binary() {
#[test]
fn test_head_binary2() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(STR)
}));
})
});
let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap();
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_body_length() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok()
.content_length(STR.len() as u64)
.content_encoding(http::ContentEncoding::Identity)
.body(Body::Streaming(Box::new(body)))}));
.body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -404,13 +440,15 @@ fn test_body_length() {
#[test]
fn test_body_chunked_explicit() {
let mut srv = test::TestServer::new(
|app| app.handler(|_| {
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok()
.chunked()
.content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))}));
.body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -428,11 +466,13 @@ fn test_body_chunked_explicit() {
#[test]
fn test_body_deflate() {
let mut srv = test::TestServer::new(
|app| app.handler(
|_| HttpResponse::Ok()
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate)
.body(STR)));
.body(STR)
})
});
// client request
let request = srv.get().disable_decompress().finish().unwrap();
@ -449,14 +489,16 @@ fn test_body_deflate() {
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_body_brotli() {
let mut srv = test::TestServer::new(
|app| app.handler(
|_| HttpResponse::Ok()
let mut srv = test::TestServer::new(|app| {
app.handler(|_| {
HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br)
.body(STR)));
.body(STR)
})
});
// client request
let request = srv.get().disable_decompress().finish().unwrap();
@ -475,14 +517,17 @@ fn test_body_brotli() {
#[test]
fn test_gzip_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
// client request
let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -491,7 +536,8 @@ fn test_gzip_encoding() {
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap();
.body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -503,14 +549,17 @@ fn test_gzip_encoding() {
#[test]
fn test_gzip_encoding_large() {
let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
// client request
let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -519,7 +568,8 @@ fn test_gzip_encoding_large() {
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap();
.body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -535,14 +585,17 @@ fn test_reading_gzip_encoding_large_random() {
.take(60_000)
.collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
// client request
let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -551,7 +604,8 @@ fn test_reading_gzip_encoding_large_random() {
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap();
.body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -563,14 +617,17 @@ fn test_reading_gzip_encoding_large_random() {
#[test]
fn test_reading_deflate_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(STR.as_ref()).unwrap();
@ -579,7 +636,8 @@ fn test_reading_deflate_encoding() {
// client request
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap();
.body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -591,14 +649,17 @@ fn test_reading_deflate_encoding() {
#[test]
fn test_reading_deflate_encoding_large() {
let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(data.as_ref()).unwrap();
@ -607,7 +668,8 @@ fn test_reading_deflate_encoding_large() {
// client request
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap();
.body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -623,14 +685,17 @@ fn test_reading_deflate_encoding_large_random() {
.take(160_000)
.collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(data.as_ref()).unwrap();
@ -639,7 +704,8 @@ fn test_reading_deflate_encoding_large_random() {
// client request
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap();
.body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -649,17 +715,20 @@ fn test_reading_deflate_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_brotli_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
let mut e = BrotliEncoder::new(Vec::new(), 5);
e.write_all(STR.as_ref()).unwrap();
@ -668,7 +737,8 @@ fn test_brotli_encoding() {
// client request
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "br")
.body(enc).unwrap();
.body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -677,18 +747,21 @@ fn test_brotli_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[cfg(feature="brotli")]
#[cfg(feature = "brotli")]
#[test]
fn test_brotli_encoding_large() {
let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
}).responder()}
));
let mut srv = test::TestServer::new(|app| {
app.handler(|req: HttpRequest| {
req.body()
.and_then(|bytes: Bytes| {
Ok(HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity)
.body(bytes))
})
.responder()
})
});
let mut e = BrotliEncoder::new(Vec::new(), 5);
e.write_all(data.as_ref()).unwrap();
@ -697,7 +770,8 @@ fn test_brotli_encoding_large() {
// client request
let request = srv.post()
.header(http::header::CONTENT_ENCODING, "br")
.body(enc).unwrap();
.body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success());
@ -708,48 +782,46 @@ fn test_brotli_encoding_large() {
#[test]
fn test_h2() {
let srv = test::TestServer::new(|app| app.handler(|_|{
HttpResponse::Ok().body(STR)
}));
let srv = test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let addr = srv.addr();
let mut core = Core::new().unwrap();
let handle = core.handle();
let tcp = TcpStream::connect(&addr, &handle);
let tcp = tcp.then(|res| {
h2client::handshake(res.unwrap())
}).then(move |res| {
let (mut client, h2) = res.unwrap();
let tcp = tcp.then(|res| h2client::handshake(res.unwrap()))
.then(move |res| {
let (mut client, h2) = res.unwrap();
let request = Request::builder()
.uri(format!("https://{}/", addr).as_str())
.body(())
.unwrap();
let (response, _) = client.send_request(request, false).unwrap();
let request = Request::builder()
.uri(format!("https://{}/", addr).as_str())
.body(())
.unwrap();
let (response, _) = client.send_request(request, false).unwrap();
// Spawn a task to run the conn...
handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e)));
// Spawn a task to run the conn...
handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e)));
response.and_then(|response| {
assert_eq!(response.status(), http::StatusCode::OK);
response.and_then(|response| {
assert_eq!(response.status(), http::StatusCode::OK);
let (_, body) = response.into_parts();
let (_, body) = response.into_parts();
body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> {
b.extend(c);
Ok(b)
body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> {
b.extend(c);
Ok(b)
})
})
})
});
});
let _res = core.run(tcp);
// assert_eq!(res.unwrap(), Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_application() {
let mut srv = test::TestServer::with_factory(
|| App::new().resource("/", |r| r.f(|_| HttpResponse::Ok())));
let mut srv = test::TestServer::with_factory(|| {
App::new().resource("/", |r| r.f(|_| HttpResponse::Ok()))
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -764,17 +836,28 @@ struct MiddlewareTest {
impl<S> middleware::Middleware<S> for MiddlewareTest {
fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> {
self.start.store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
self.start.store(
self.start.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
Ok(middleware::Started::Done)
}
fn response(&self, _: &mut HttpRequest<S>, resp: HttpResponse) -> Result<middleware::Response> {
self.response.store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
fn response(
&self, _: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<middleware::Response> {
self.response.store(
self.response.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
Ok(middleware::Response::Done(resp))
}
fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished {
self.finish.store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
self.finish.store(
self.finish.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
middleware::Finished::Done
}
}
@ -789,12 +872,13 @@ fn test_middlewares() {
let act_num2 = Arc::clone(&num2);
let act_num3 = Arc::clone(&num3);
let mut srv = test::TestServer::new(
move |app| app.middleware(MiddlewareTest{start: Arc::clone(&act_num1),
response: Arc::clone(&act_num2),
finish: Arc::clone(&act_num3)})
.handler(|_| HttpResponse::Ok())
);
let mut srv = test::TestServer::new(move |app| {
app.middleware(MiddlewareTest {
start: Arc::clone(&act_num1),
response: Arc::clone(&act_num2),
finish: Arc::clone(&act_num3),
}).handler(|_| HttpResponse::Ok())
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();
@ -805,7 +889,6 @@ fn test_middlewares() {
assert_eq!(num3.load(Ordering::Relaxed), 1);
}
#[test]
fn test_resource_middlewares() {
let num1 = Arc::new(AtomicUsize::new(0));
@ -816,13 +899,13 @@ fn test_resource_middlewares() {
let act_num2 = Arc::clone(&num2);
let act_num3 = Arc::clone(&num3);
let mut srv = test::TestServer::new(
move |app| app
.middleware(MiddlewareTest{start: Arc::clone(&act_num1),
response: Arc::clone(&act_num2),
finish: Arc::clone(&act_num3)})
.handler(|_| HttpResponse::Ok())
);
let mut srv = test::TestServer::new(move |app| {
app.middleware(MiddlewareTest {
start: Arc::clone(&act_num1),
response: Arc::clone(&act_num2),
finish: Arc::clone(&act_num3),
}).handler(|_| HttpResponse::Ok())
});
let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap();

View File

@ -1,19 +1,19 @@
extern crate actix;
extern crate actix_web;
extern crate bytes;
extern crate futures;
extern crate http;
extern crate bytes;
extern crate rand;
use bytes::Bytes;
use futures::Stream;
use rand::Rng;
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
extern crate openssl;
use actix_web::*;
use actix::prelude::*;
use actix_web::*;
struct Ws;
@ -22,7 +22,6 @@ impl Actor for Ws {
}
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg {
ws::Message::Ping(msg) => ctx.pong(&msg),
@ -36,8 +35,7 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
#[test]
fn test_simple() {
let mut srv = test::TestServer::new(
|app| app.handler(|req| ws::start(req, Ws)));
let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
let (reader, mut writer) = srv.ws().unwrap();
writer.text("text");
@ -46,7 +44,12 @@ fn test_simple() {
writer.binary(b"text".as_ref());
let (item, reader) = srv.execute(reader.into_future()).unwrap();
assert_eq!(item, Some(ws::Message::Binary(Bytes::from_static(b"text").into())));
assert_eq!(
item,
Some(ws::Message::Binary(
Bytes::from_static(b"text").into()
))
);
writer.ping("ping");
let (item, reader) = srv.execute(reader.into_future()).unwrap();
@ -64,8 +67,7 @@ fn test_large_text() {
.take(65_536)
.collect::<String>();
let mut srv = test::TestServer::new(
|app| app.handler(|req| ws::start(req, Ws)));
let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
let (mut reader, mut writer) = srv.ws().unwrap();
for _ in 0..100 {
@ -83,15 +85,17 @@ fn test_large_bin() {
.take(65_536)
.collect::<String>();
let mut srv = test::TestServer::new(
|app| app.handler(|req| ws::start(req, Ws)));
let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
let (mut reader, mut writer) = srv.ws().unwrap();
for _ in 0..100 {
writer.binary(data.clone());
let (item, r) = srv.execute(reader.into_future()).unwrap();
reader = r;
assert_eq!(item, Some(ws::Message::Binary(Binary::from(data.clone()))));
assert_eq!(
item,
Some(ws::Message::Binary(Binary::from(data.clone())))
);
}
}
@ -115,18 +119,19 @@ impl Ws2 {
} else {
ctx.text("0".repeat(65_536));
}
ctx.drain().and_then(|_, act, ctx| {
act.count += 1;
if act.count != 10_000 {
act.send(ctx);
}
actix::fut::ok(())
}).wait(ctx);
ctx.drain()
.and_then(|_, act, ctx| {
act.count += 1;
if act.count != 10_000 {
act.send(ctx);
}
actix::fut::ok(())
})
.wait(ctx);
}
}
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws2 {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg {
ws::Message::Ping(msg) => ctx.pong(&msg),
@ -142,8 +147,17 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for Ws2 {
fn test_server_send_text() {
let data = Some(ws::Message::Text("0".repeat(65_536)));
let mut srv = test::TestServer::new(
|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false})));
let mut srv = test::TestServer::new(|app| {
app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: false,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap();
for _ in 0..10_000 {
@ -157,8 +171,17 @@ fn test_server_send_text() {
fn test_server_send_bin() {
let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536))));
let mut srv = test::TestServer::new(
|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: true})));
let mut srv = test::TestServer::new(|app| {
app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: true,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap();
for _ in 0..10_000 {
@ -169,19 +192,33 @@ fn test_server_send_bin() {
}
#[test]
#[cfg(feature="alpn")]
#[cfg(feature = "alpn")]
fn test_ws_server_ssl() {
extern crate openssl;
use openssl::ssl::{SslMethod, SslAcceptor, SslFiletype};
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder.set_private_key_file("tests/key.pem", SslFiletype::PEM).unwrap();
builder.set_certificate_chain_file("tests/cert.pem").unwrap();
builder
.set_private_key_file("tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("tests/cert.pem")
.unwrap();
let mut srv = test::TestServer::build()
.ssl(builder.build())
.start(|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false})));
.start(|app| {
app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: false,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap();
let data = Some(ws::Message::Text("0".repeat(65_536)));

View File

@ -3,25 +3,24 @@
#![allow(unused_variables)]
extern crate actix;
extern crate actix_web;
extern crate clap;
extern crate env_logger;
extern crate futures;
extern crate tokio_core;
extern crate url;
extern crate clap;
extern crate num_cpus;
extern crate rand;
extern crate time;
extern crate num_cpus;
extern crate tokio_core;
extern crate url;
use std::time::Duration;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use futures::Future;
use rand::{thread_rng, Rng};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use actix::prelude::*;
use actix_web::ws;
fn main() {
::std::env::set_var("RUST_LOG", "actix_web=info");
let _ = env_logger::init();
@ -62,16 +61,20 @@ fn main() {
let sample_rate = parse_u64_default(matches.value_of("sample-rate"), 1) as usize;
let perf_counters = Arc::new(PerfCounters::new());
let payload = Arc::new(thread_rng()
.gen_ascii_chars()
.take(payload_size)
.collect::<String>());
let payload = Arc::new(
thread_rng()
.gen_ascii_chars()
.take(payload_size)
.collect::<String>(),
);
let sys = actix::System::new("ws-client");
let _: () = Perf{counters: perf_counters.clone(),
payload: payload.len(),
sample_rate_secs: sample_rate}.start();
let _: () = Perf {
counters: perf_counters.clone(),
payload: payload.len(),
sample_rate_secs: sample_rate,
}.start();
for t in 0..threads {
let pl = payload.clone();
@ -79,46 +82,54 @@ fn main() {
let perf = perf_counters.clone();
let addr = Arbiter::new(format!("test {}", t));
addr.do_send(actix::msgs::Execute::new(move || -> Result<(), ()> {
for _ in 0..concurrency {
let pl2 = pl.clone();
let perf2 = perf.clone();
let ws2 = ws.clone();
addr.do_send(actix::msgs::Execute::new(
move || -> Result<(), ()> {
for _ in 0..concurrency {
let pl2 = pl.clone();
let perf2 = perf.clone();
let ws2 = ws.clone();
Arbiter::handle().spawn(
ws::Client::new(&ws)
.write_buffer_capacity(0)
.connect()
.map_err(|e| {
println!("Error: {}", e);
//Arbiter::system().do_send(actix::msgs::SystemExit(0));
()
})
.map(move |(reader, writer)| {
let addr: Addr<Syn, _> = ChatClient::create(move |ctx| {
ChatClient::add_stream(reader, ctx);
ChatClient{url: ws2,
conn: writer,
payload: pl2,
bin: bin,
ts: time::precise_time_ns(),
perf_counters: perf2,
sent: 0,
max_payload_size: max_payload_size,
}
});
})
);
}
Ok(())
}));
Arbiter::handle().spawn(
ws::Client::new(&ws)
.write_buffer_capacity(0)
.connect()
.map_err(|e| {
println!("Error: {}", e);
//Arbiter::system().do_send(actix::msgs::SystemExit(0));
()
})
.map(move |(reader, writer)| {
let addr: Addr<Syn, _> =
ChatClient::create(move |ctx| {
ChatClient::add_stream(reader, ctx);
ChatClient {
url: ws2,
conn: writer,
payload: pl2,
bin: bin,
ts: time::precise_time_ns(),
perf_counters: perf2,
sent: 0,
max_payload_size: max_payload_size,
}
});
}),
);
}
Ok(())
},
));
}
let res = sys.run();
}
fn parse_u64_default(input: Option<&str>, default: u64) -> u64 {
input.map(|v| v.parse().expect(&format!("not a valid number: {}", v)))
input
.map(|v| {
v.parse()
.expect(&format!("not a valid number: {}", v))
})
.unwrap_or(default)
}
@ -138,29 +149,32 @@ impl Actor for Perf {
impl Perf {
fn sample_rate(&self, ctx: &mut Context<Self>) {
ctx.run_later(Duration::new(self.sample_rate_secs as u64, 0), |act, ctx| {
let req_count = act.counters.pull_request_count();
if req_count != 0 {
let conns = act.counters.pull_connections_count();
let latency = act.counters.pull_latency_ns();
let latency_max = act.counters.pull_latency_max_ns();
println!(
"rate: {}, conns: {}, throughput: {:?} kb, latency: {}, latency max: {}",
req_count / act.sample_rate_secs,
conns / act.sample_rate_secs,
(((req_count * act.payload) as f64) / 1024.0) /
act.sample_rate_secs as f64,
time::Duration::nanoseconds((latency / req_count as u64) as i64),
time::Duration::nanoseconds(latency_max as i64)
);
}
ctx.run_later(
Duration::new(self.sample_rate_secs as u64, 0),
|act, ctx| {
let req_count = act.counters.pull_request_count();
if req_count != 0 {
let conns = act.counters.pull_connections_count();
let latency = act.counters.pull_latency_ns();
let latency_max = act.counters.pull_latency_max_ns();
println!(
"rate: {}, conns: {}, throughput: {:?} kb, latency: {}, latency max: {}",
req_count / act.sample_rate_secs,
conns / act.sample_rate_secs,
(((req_count * act.payload) as f64) / 1024.0)
/ act.sample_rate_secs as f64,
time::Duration::nanoseconds((latency / req_count as u64) as i64),
time::Duration::nanoseconds(latency_max as i64)
);
}
act.sample_rate(ctx);
});
act.sample_rate(ctx);
},
);
}
}
struct ChatClient{
struct ChatClient {
url: String,
conn: ws::ClientWriter,
payload: Arc<String>,
@ -181,7 +195,6 @@ impl Actor for ChatClient {
}
impl ChatClient {
fn send_text(&mut self) -> bool {
self.sent += self.payload.len();
@ -193,7 +206,8 @@ impl ChatClient {
let max_payload_size = self.max_payload_size;
Arbiter::handle().spawn(
ws::Client::new(&self.url).connect()
ws::Client::new(&self.url)
.connect()
.map_err(|e| {
println!("Error: {}", e);
Arbiter::system().do_send(actix::msgs::SystemExit(0));
@ -202,17 +216,18 @@ impl ChatClient {
.map(move |(reader, writer)| {
let addr: Addr<Syn, _> = ChatClient::create(move |ctx| {
ChatClient::add_stream(reader, ctx);
ChatClient{url: ws,
conn: writer,
payload: pl,
bin: bin,
ts: time::precise_time_ns(),
perf_counters: perf_counters,
sent: 0,
max_payload_size: max_payload_size,
ChatClient {
url: ws,
conn: writer,
payload: pl,
bin: bin,
ts: time::precise_time_ns(),
perf_counters: perf_counters,
sent: 0,
max_payload_size: max_payload_size,
}
});
})
}),
);
false
} else {
@ -229,7 +244,6 @@ impl ChatClient {
/// Handle server websocket messages
impl StreamHandler<ws::Message, ws::ProtocolError> for ChatClient {
fn finished(&mut self, ctx: &mut Context<Self>) {
ctx.stop()
}
@ -239,25 +253,25 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for ChatClient {
ws::Message::Text(txt) => {
if txt == self.payload.as_ref().as_str() {
self.perf_counters.register_request();
self.perf_counters.register_latency(time::precise_time_ns() - self.ts);
self.perf_counters
.register_latency(time::precise_time_ns() - self.ts);
if !self.send_text() {
ctx.stop();
}
} else {
println!("not eaqual");
}
},
_ => ()
}
_ => (),
}
}
}
pub struct PerfCounters {
req: AtomicUsize,
conn: AtomicUsize,
lat: AtomicUsize,
lat_max: AtomicUsize
lat_max: AtomicUsize,
}
impl PerfCounters {
@ -299,7 +313,11 @@ impl PerfCounters {
self.lat.fetch_add(nanos, Ordering::SeqCst);
loop {
let current = self.lat_max.load(Ordering::SeqCst);
if current >= nanos || self.lat_max.compare_and_swap(current, nanos, Ordering::SeqCst) == current {
if current >= nanos
|| self.lat_max
.compare_and_swap(current, nanos, Ordering::SeqCst)
== current
{
break;
}
}