1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 00:21:08 +01:00

add rustfmt config

This commit is contained in:
Nikolay Kim 2018-04-13 16:02:01 -07:00
parent 95f6277007
commit 113f5ad1a8
91 changed files with 8057 additions and 5509 deletions

View File

@ -3,7 +3,6 @@ extern crate version_check;
use std::{env, fs}; use std::{env, fs};
#[cfg(unix)] #[cfg(unix)]
fn main() { fn main() {
println!("cargo:rerun-if-env-changed=USE_SKEPTIC"); println!("cargo:rerun-if-env-changed=USE_SKEPTIC");
@ -11,23 +10,23 @@ fn main() {
if env::var("USE_SKEPTIC").is_ok() { if env::var("USE_SKEPTIC").is_ok() {
let _ = fs::remove_file(f); let _ = fs::remove_file(f);
// generates doc tests for `README.md`. // generates doc tests for `README.md`.
skeptic::generate_doc_tests( skeptic::generate_doc_tests(&[
&[// "README.md", // "README.md",
"guide/src/qs_1.md", "guide/src/qs_1.md",
"guide/src/qs_2.md", "guide/src/qs_2.md",
"guide/src/qs_3.md", "guide/src/qs_3.md",
"guide/src/qs_3_5.md", "guide/src/qs_3_5.md",
"guide/src/qs_4.md", "guide/src/qs_4.md",
"guide/src/qs_4_5.md", "guide/src/qs_4_5.md",
"guide/src/qs_5.md", "guide/src/qs_5.md",
"guide/src/qs_7.md", "guide/src/qs_7.md",
"guide/src/qs_8.md", "guide/src/qs_8.md",
"guide/src/qs_9.md", "guide/src/qs_9.md",
"guide/src/qs_10.md", "guide/src/qs_10.md",
"guide/src/qs_12.md", "guide/src/qs_12.md",
"guide/src/qs_13.md", "guide/src/qs_13.md",
"guide/src/qs_14.md", "guide/src/qs_14.md",
]); ]);
} else { } else {
let _ = fs::File::create(f); let _ = fs::File::create(f);
} }

7
rustfmt.toml Normal file
View File

@ -0,0 +1,7 @@
max_width = 89
reorder_imports = true
reorder_imports_in_group = true
reorder_imported_names = true
wrap_comments = true
fn_args_density = "Compressed"
#use_small_heuristics = false

View File

@ -1,24 +1,24 @@
use std::mem;
use std::rc::Rc;
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::mem;
use std::rc::Rc;
use http::Method;
use handler::Reply; use handler::Reply;
use router::{Router, Resource}; use handler::{FromRequest, Handler, Responder, RouteHandler, WrapHandler};
use resource::{ResourceHandler};
use header::ContentEncoding; use header::ContentEncoding;
use handler::{Handler, RouteHandler, WrapHandler, FromRequest, Responder}; use http::Method;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use pipeline::{Pipeline, PipelineHandler, HandlerType};
use middleware::Middleware; use middleware::Middleware;
use server::{HttpHandler, IntoHttpHandler, HttpHandlerTask, ServerSettings}; use pipeline::{HandlerType, Pipeline, PipelineHandler};
use resource::ResourceHandler;
use router::{Resource, Router};
use server::{HttpHandler, HttpHandlerTask, IntoHttpHandler, ServerSettings};
#[deprecated(since="0.5.0", note="please use `actix_web::App` instead")] #[deprecated(since = "0.5.0", note = "please use `actix_web::App` instead")]
pub type Application<S> = App<S>; pub type Application<S> = App<S>;
/// Application /// Application
pub struct HttpApplication<S=()> { pub struct HttpApplication<S = ()> {
state: Rc<S>, state: Rc<S>,
prefix: String, prefix: String,
prefix_len: usize, prefix_len: usize,
@ -36,28 +36,25 @@ pub(crate) struct Inner<S> {
} }
impl<S: 'static> PipelineHandler<S> for Inner<S> { impl<S: 'static> PipelineHandler<S> for Inner<S> {
fn encoding(&self) -> ContentEncoding { fn encoding(&self) -> ContentEncoding {
self.encoding self.encoding
} }
fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply { fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply {
match htype { match htype {
HandlerType::Normal(idx) => HandlerType::Normal(idx) => {
self.resources[idx].handle(req, Some(&mut self.default)), self.resources[idx].handle(req, Some(&mut self.default))
HandlerType::Handler(idx) => }
self.handlers[idx].1.handle(req), HandlerType::Handler(idx) => self.handlers[idx].1.handle(req),
HandlerType::Default => HandlerType::Default => self.default.handle(req, None),
self.default.handle(req, None)
} }
} }
} }
impl<S: 'static> HttpApplication<S> { impl<S: 'static> HttpApplication<S> {
#[inline] #[inline]
fn as_ref(&self) -> &Inner<S> { fn as_ref(&self) -> &Inner<S> {
unsafe{&*self.inner.get()} unsafe { &*self.inner.get() }
} }
#[inline] #[inline]
@ -70,20 +67,21 @@ impl<S: 'static> HttpApplication<S> {
let &(ref prefix, _) = &inner.handlers[idx]; let &(ref prefix, _) = &inner.handlers[idx];
let m = { let m = {
let path = &req.path()[inner.prefix..]; let path = &req.path()[inner.prefix..];
path.starts_with(prefix) && ( path.starts_with(prefix)
path.len() == prefix.len() || && (path.len() == prefix.len()
path.split_at(prefix.len()).1.starts_with('/')) || path.split_at(prefix.len()).1.starts_with('/'))
}; };
if m { if m {
let path: &'static str = unsafe { let path: &'static str = unsafe {
mem::transmute(&req.path()[inner.prefix+prefix.len()..]) }; mem::transmute(&req.path()[inner.prefix + prefix.len()..])
};
if path.is_empty() { if path.is_empty() {
req.match_info_mut().add("tail", ""); req.match_info_mut().add("tail", "");
} else { } else {
req.match_info_mut().add("tail", path.split_at(1).1); req.match_info_mut().add("tail", path.split_at(1).1);
} }
return HandlerType::Handler(idx) return HandlerType::Handler(idx);
} }
} }
HandlerType::Default HandlerType::Default
@ -93,7 +91,7 @@ impl<S: 'static> HttpApplication<S> {
#[cfg(test)] #[cfg(test)]
pub(crate) fn run(&mut self, mut req: HttpRequest<S>) -> Reply { pub(crate) fn run(&mut self, mut req: HttpRequest<S>) -> Reply {
let tp = self.get_handler(&mut req); let tp = self.get_handler(&mut req);
unsafe{&mut *self.inner.get()}.handle(req, tp) unsafe { &mut *self.inner.get() }.handle(req, tp)
} }
#[cfg(test)] #[cfg(test)]
@ -103,19 +101,23 @@ impl<S: 'static> HttpApplication<S> {
} }
impl<S: 'static> HttpHandler for HttpApplication<S> { impl<S: 'static> HttpHandler for HttpApplication<S> {
fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest> { fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest> {
let m = { let m = {
let path = req.path(); let path = req.path();
path.starts_with(&self.prefix) && ( path.starts_with(&self.prefix)
path.len() == self.prefix_len || && (path.len() == self.prefix_len
path.split_at(self.prefix_len).1.starts_with('/')) || path.split_at(self.prefix_len).1.starts_with('/'))
}; };
if m { if m {
let mut req = req.with_state(Rc::clone(&self.state), self.router.clone()); let mut req = req.with_state(Rc::clone(&self.state), self.router.clone());
let tp = self.get_handler(&mut req); let tp = self.get_handler(&mut req);
let inner = Rc::clone(&self.inner); let inner = Rc::clone(&self.inner);
Ok(Box::new(Pipeline::new(req, Rc::clone(&self.middlewares), inner, tp))) Ok(Box::new(Pipeline::new(
req,
Rc::clone(&self.middlewares),
inner,
tp,
)))
} else { } else {
Err(req) Err(req)
} }
@ -134,14 +136,14 @@ struct ApplicationParts<S> {
middlewares: Vec<Box<Middleware<S>>>, middlewares: Vec<Box<Middleware<S>>>,
} }
/// Structure that follows the builder pattern for building application instances. /// Structure that follows the builder pattern for building application
pub struct App<S=()> { /// instances.
pub struct App<S = ()> {
parts: Option<ApplicationParts<S>>, parts: Option<ApplicationParts<S>>,
} }
impl App<()> { impl App<()> {
/// Create application with empty state. Application can
/// Create application with empty state. Application can
/// be configured with a builder-like pattern. /// be configured with a builder-like pattern.
pub fn new() -> App<()> { pub fn new() -> App<()> {
App { App {
@ -155,7 +157,7 @@ impl App<()> {
external: HashMap::new(), external: HashMap::new(),
encoding: ContentEncoding::Auto, encoding: ContentEncoding::Auto,
middlewares: Vec::new(), middlewares: Vec::new(),
}) }),
} }
} }
} }
@ -166,8 +168,10 @@ impl Default for App<()> {
} }
} }
impl<S> App<S> where S: 'static { impl<S> App<S>
where
S: 'static,
{
/// Create application with specified state. Application can be /// Create application with specified state. Application can be
/// configured with a builder-like pattern. /// configured with a builder-like pattern.
/// ///
@ -185,7 +189,7 @@ impl<S> App<S> where S: 'static {
external: HashMap::new(), external: HashMap::new(),
middlewares: Vec::new(), middlewares: Vec::new(),
encoding: ContentEncoding::Auto, encoding: ContentEncoding::Auto,
}) }),
} }
} }
@ -256,20 +260,22 @@ impl<S> App<S> where S: 'static {
/// } /// }
/// ``` /// ```
pub fn route<T, F, R>(mut self, path: &str, method: Method, f: F) -> App<S> pub fn route<T, F, R>(mut self, path: &str, method: Method, f: F) -> App<S>
where F: Fn(T) -> R + 'static, where
R: Responder + 'static, F: Fn(T) -> R + 'static,
T: FromRequest<S> + 'static, R: Responder + 'static,
T: FromRequest<S> + 'static,
{ {
{ {
let parts: &mut ApplicationParts<S> = unsafe{ let parts: &mut ApplicationParts<S> = unsafe {
mem::transmute(self.parts.as_mut().expect("Use after finish"))}; mem::transmute(self.parts.as_mut().expect("Use after finish"))
};
// get resource handler // get resource handler
for &mut (ref pattern, ref mut handler) in &mut parts.resources { for &mut (ref pattern, ref mut handler) in &mut parts.resources {
if let Some(ref mut handler) = *handler { if let Some(ref mut handler) = *handler {
if pattern.pattern() == path { if pattern.pattern() == path {
handler.method(method).with(f); handler.method(method).with(f);
return self return self;
} }
} }
} }
@ -315,7 +321,8 @@ impl<S> App<S> where S: 'static {
/// } /// }
/// ``` /// ```
pub fn resource<F, R>(mut self, path: &str, f: F) -> App<S> pub fn resource<F, R>(mut self, path: &str, f: F) -> App<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{ {
{ {
let parts = self.parts.as_mut().expect("Use after finish"); let parts = self.parts.as_mut().expect("Use after finish");
@ -334,13 +341,17 @@ impl<S> App<S> where S: 'static {
#[doc(hidden)] #[doc(hidden)]
pub fn register_resource(&mut self, path: &str, resource: ResourceHandler<S>) { pub fn register_resource(&mut self, path: &str, resource: ResourceHandler<S>) {
let pattern = Resource::new(resource.get_name(), path); let pattern = Resource::new(resource.get_name(), path);
self.parts.as_mut().expect("Use after finish") self.parts
.resources.push((pattern, Some(resource))); .as_mut()
.expect("Use after finish")
.resources
.push((pattern, Some(resource)));
} }
/// Default resource to be used if no matching route could be found. /// Default resource to be used if no matching route could be found.
pub fn default_resource<F, R>(mut self, f: F) -> App<S> pub fn default_resource<F, R>(mut self, f: F) -> App<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{ {
{ {
let parts = self.parts.as_mut().expect("Use after finish"); let parts = self.parts.as_mut().expect("Use after finish");
@ -350,8 +361,7 @@ impl<S> App<S> where S: 'static {
} }
/// Set default content encoding. `ContentEncoding::Auto` is set by default. /// Set default content encoding. `ContentEncoding::Auto` is set by default.
pub fn default_encoding(mut self, encoding: ContentEncoding) -> App<S> pub fn default_encoding(mut self, encoding: ContentEncoding) -> App<S> {
{
{ {
let parts = self.parts.as_mut().expect("Use after finish"); let parts = self.parts.as_mut().expect("Use after finish");
parts.encoding = encoding; parts.encoding = encoding;
@ -383,7 +393,9 @@ impl<S> App<S> where S: 'static {
/// } /// }
/// ``` /// ```
pub fn external_resource<T, U>(mut self, name: T, url: U) -> App<S> pub fn external_resource<T, U>(mut self, name: T, url: U) -> App<S>
where T: AsRef<str>, U: AsRef<str> where
T: AsRef<str>,
U: AsRef<str>,
{ {
{ {
let parts = self.parts.as_mut().expect("Use after finish"); let parts = self.parts.as_mut().expect("Use after finish");
@ -393,7 +405,8 @@ impl<S> App<S> where S: 'static {
} }
parts.external.insert( parts.external.insert(
String::from(name.as_ref()), String::from(name.as_ref()),
Resource::external(name.as_ref(), url.as_ref())); Resource::external(name.as_ref(), url.as_ref()),
);
} }
self self
} }
@ -419,8 +432,7 @@ impl<S> App<S> where S: 'static {
/// }}); /// }});
/// } /// }
/// ``` /// ```
pub fn handler<H: Handler<S>>(mut self, path: &str, handler: H) -> App<S> pub fn handler<H: Handler<S>>(mut self, path: &str, handler: H) -> App<S> {
{
{ {
let mut path = path.trim().trim_right_matches('/').to_owned(); let mut path = path.trim().trim_right_matches('/').to_owned();
if !path.is_empty() && !path.starts_with('/') { if !path.is_empty() && !path.starts_with('/') {
@ -428,15 +440,20 @@ impl<S> App<S> where S: 'static {
} }
let parts = self.parts.as_mut().expect("Use after finish"); let parts = self.parts.as_mut().expect("Use after finish");
parts.handlers.push((path, Box::new(WrapHandler::new(handler)))); parts
.handlers
.push((path, Box::new(WrapHandler::new(handler))));
} }
self self
} }
/// Register a middleware. /// Register a middleware.
pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> App<S> { pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> App<S> {
self.parts.as_mut().expect("Use after finish") self.parts
.middlewares.push(Box::new(mw)); .as_mut()
.expect("Use after finish")
.middlewares
.push(Box::new(mw));
self self
} }
@ -468,7 +485,8 @@ impl<S> App<S> where S: 'static {
/// } /// }
/// ``` /// ```
pub fn configure<F>(self, cfg: F) -> App<S> pub fn configure<F>(self, cfg: F) -> App<S>
where F: Fn(App<S>) -> App<S> where
F: Fn(App<S>) -> App<S>,
{ {
cfg(self) cfg(self)
} }
@ -490,15 +508,13 @@ impl<S> App<S> where S: 'static {
let (router, resources) = Router::new(&prefix, parts.settings, resources); let (router, resources) = Router::new(&prefix, parts.settings, resources);
let inner = Rc::new(UnsafeCell::new( let inner = Rc::new(UnsafeCell::new(Inner {
Inner { prefix: prefix_len,
prefix: prefix_len, default: parts.default,
default: parts.default, encoding: parts.encoding,
encoding: parts.encoding, handlers: parts.handlers,
handlers: parts.handlers, resources,
resources, }));
}
));
HttpApplication { HttpApplication {
state: Rc::new(parts.state), state: Rc::new(parts.state),
@ -582,14 +598,13 @@ impl<S: 'static> Iterator for App<S> {
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use http::StatusCode;
use super::*; use super::*;
use test::TestRequest; use http::StatusCode;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use test::TestRequest;
#[test] #[test]
fn test_default_resource() { fn test_default_resource() {
@ -603,14 +618,20 @@ mod tests {
let req = TestRequest::with_uri("/blah").finish(); let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let mut app = App::new() let mut app = App::new()
.default_resource(|r| r.f(|_| HttpResponse::MethodNotAllowed())) .default_resource(|r| r.f(|_| HttpResponse::MethodNotAllowed()))
.finish(); .finish();
let req = TestRequest::with_uri("/blah").finish(); let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::METHOD_NOT_ALLOWED
);
} }
#[test] #[test]
@ -627,7 +648,8 @@ mod tests {
let mut app = App::with_state(10) let mut app = App::with_state(10)
.resource("/", |r| r.f(|_| HttpResponse::Ok())) .resource("/", |r| r.f(|_| HttpResponse::Ok()))
.finish(); .finish();
let req = HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone()); let req =
HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone());
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK);
} }
@ -675,11 +697,17 @@ mod tests {
let req = TestRequest::with_uri("/testapp").finish(); let req = TestRequest::with_uri("/testapp").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/blah").finish(); let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
} }
#[test] #[test]
@ -702,11 +730,17 @@ mod tests {
let req = TestRequest::with_uri("/testapp").finish(); let req = TestRequest::with_uri("/testapp").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/blah").finish(); let req = TestRequest::with_uri("/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
} }
#[test] #[test]
@ -730,31 +764,53 @@ mod tests {
let req = TestRequest::with_uri("/prefix/testapp").finish(); let req = TestRequest::with_uri("/prefix/testapp").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/prefix/blah").finish(); let req = TestRequest::with_uri("/prefix/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
} }
#[test] #[test]
fn test_route() { fn test_route() {
let mut app = App::new() let mut app = App::new()
.route("/test", Method::GET, |_: HttpRequest| HttpResponse::Ok()) .route("/test", Method::GET, |_: HttpRequest| {
.route("/test", Method::POST, |_: HttpRequest| HttpResponse::Created()) HttpResponse::Ok()
})
.route("/test", Method::POST, |_: HttpRequest| {
HttpResponse::Created()
})
.finish(); .finish();
let req = TestRequest::with_uri("/test").method(Method::GET).finish(); let req = TestRequest::with_uri("/test")
.method(Method::GET)
.finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK);
let req = TestRequest::with_uri("/test").method(Method::POST).finish(); let req = TestRequest::with_uri("/test")
.method(Method::POST)
.finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::CREATED); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::CREATED
);
let req = TestRequest::with_uri("/test").method(Method::HEAD).finish(); let req = TestRequest::with_uri("/test")
.method(Method::HEAD)
.finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
} }
#[test] #[test]
@ -766,7 +822,10 @@ mod tests {
let req = TestRequest::with_uri("/test").finish(); let req = TestRequest::with_uri("/test").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/app/test").finish(); let req = TestRequest::with_uri("/app/test").finish();
let resp = app.run(req); let resp = app.run(req);
@ -782,12 +841,16 @@ mod tests {
let req = TestRequest::with_uri("/app/testapp").finish(); let req = TestRequest::with_uri("/app/testapp").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
let req = TestRequest::with_uri("/app/blah").finish(); let req = TestRequest::with_uri("/app/blah").finish();
let resp = app.run(req); let resp = app.run(req);
assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); assert_eq!(
resp.as_response().unwrap().status(),
StatusCode::NOT_FOUND
);
} }
} }

View File

@ -1,18 +1,17 @@
use std::{fmt, mem};
use std::rc::Rc;
use std::sync::Arc;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::Stream; use futures::Stream;
use std::rc::Rc;
use std::sync::Arc;
use std::{fmt, mem};
use error::Error;
use context::ActorHttpContext; use context::ActorHttpContext;
use error::Error;
use handler::Responder; use handler::Responder;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
/// Type represent streaming body /// Type represent streaming body
pub type BodyStream = Box<Stream<Item=Bytes, Error=Error>>; pub type BodyStream = Box<Stream<Item = Bytes, Error = Error>>;
/// Represents various types of http message body. /// Represents various types of http message body.
pub enum Body { pub enum Body {
@ -50,7 +49,7 @@ impl Body {
pub fn is_streaming(&self) -> bool { pub fn is_streaming(&self) -> bool {
match *self { match *self {
Body::Streaming(_) | Body::Actor(_) => true, Body::Streaming(_) | Body::Actor(_) => true,
_ => false _ => false,
} }
} }
@ -59,7 +58,7 @@ impl Body {
pub fn is_binary(&self) -> bool { pub fn is_binary(&self) -> bool {
match *self { match *self {
Body::Binary(_) => true, Body::Binary(_) => true,
_ => false _ => false,
} }
} }
@ -96,7 +95,10 @@ impl fmt::Debug for Body {
} }
} }
impl<T> From<T> for Body where T: Into<Binary>{ impl<T> From<T> for Body
where
T: Into<Binary>,
{
fn from(b: T) -> Body { fn from(b: T) -> Body {
Body::Binary(b.into()) Body::Binary(b.into())
} }
@ -257,8 +259,8 @@ impl Responder for Binary {
fn respond_to(self, _: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, _: HttpRequest) -> Result<HttpResponse, Error> {
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
.content_type("application/octet-stream") .content_type("application/octet-stream")
.body(self)) .body(self))
} }
} }
@ -349,7 +351,7 @@ mod tests {
#[test] #[test]
fn test_bytes_mut() { fn test_bytes_mut() {
let b = BytesMut::from("test"); let b = BytesMut::from("test");
assert_eq!(Binary::from(b.clone()).len(), 4); assert_eq!(Binary::from(b.clone()).len(), 4);
assert_eq!(Binary::from(b).as_ref(), b"test"); assert_eq!(Binary::from(b).as_ref(), b"test");
} }

View File

@ -1,36 +1,35 @@
use std::{fmt, mem, io, time};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::net::Shutdown;
use std::time::{Duration, Instant};
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::net::Shutdown;
use std::rc::Rc;
use std::time::{Duration, Instant};
use std::{fmt, io, mem, time};
use actix::{fut, Actor, ActorFuture, Arbiter, Context, AsyncContext, use actix::actors::{Connect as ResolveConnect, Connector, ConnectorError};
Recipient, Syn, Handler, Message, ActorResponse,
Supervised, ContextFutureSpawner};
use actix::registry::ArbiterService;
use actix::fut::WrapFuture; use actix::fut::WrapFuture;
use actix::actors::{Connector, ConnectorError, Connect as ResolveConnect}; use actix::registry::ArbiterService;
use actix::{fut, Actor, ActorFuture, ActorResponse, Arbiter, AsyncContext, Context,
ContextFutureSpawner, Handler, Message, Recipient, Supervised, Syn};
use http::{Uri, HttpTryFrom, Error as HttpError}; use futures::task::{current as current_task, Task};
use futures::{Async, Future, Poll};
use futures::task::{Task, current as current_task};
use futures::unsync::oneshot; use futures::unsync::oneshot;
use tokio_io::{AsyncRead, AsyncWrite}; use futures::{Async, Future, Poll};
use http::{Error as HttpError, HttpTryFrom, Uri};
use tokio_core::reactor::Timeout; use tokio_core::reactor::Timeout;
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use openssl::ssl::{SslMethod, SslConnector, Error as OpensslError}; use openssl::ssl::{Error as OpensslError, SslConnector, SslMethod};
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use tokio_openssl::SslConnectorExt; use tokio_openssl::SslConnectorExt;
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
use native_tls::{TlsConnector, Error as TlsError}; use native_tls::{Error as TlsError, TlsConnector};
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
use tokio_tls::TlsConnectorExt; use tokio_tls::TlsConnectorExt;
use {HAS_OPENSSL, HAS_TLS};
use server::IoStream; use server::IoStream;
use {HAS_OPENSSL, HAS_TLS};
/// Client connector usage stats /// Client connector usage stats
#[derive(Default, Message)] #[derive(Default, Message)]
@ -54,7 +53,10 @@ pub struct Connect {
impl Connect { impl Connect {
/// Create `Connect` message for specified `Uri` /// Create `Connect` message for specified `Uri`
pub fn new<U>(uri: U) -> Result<Connect, HttpError> where Uri: HttpTryFrom<U> { pub fn new<U>(uri: U) -> Result<Connect, HttpError>
where
Uri: HttpTryFrom<U>,
{
Ok(Connect { Ok(Connect {
uri: Uri::try_from(uri).map_err(|e| e.into())?, uri: Uri::try_from(uri).map_err(|e| e.into())?,
wait_timeout: Duration::from_secs(5), wait_timeout: Duration::from_secs(5),
@ -92,13 +94,13 @@ pub struct Pause {
impl Pause { impl Pause {
/// Create message with pause duration parameter /// Create message with pause duration parameter
pub fn new(time: Duration) -> Pause { pub fn new(time: Duration) -> Pause {
Pause{time: Some(time)} Pause { time: Some(time) }
} }
} }
impl Default for Pause { impl Default for Pause {
fn default() -> Pause { fn default() -> Pause {
Pause{time: None} Pause { time: None }
} }
} }
@ -114,21 +116,21 @@ pub struct Resume;
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum ClientConnectorError { pub enum ClientConnectorError {
/// Invalid URL /// Invalid URL
#[fail(display="Invalid URL")] #[fail(display = "Invalid URL")]
InvalidUrl, InvalidUrl,
/// SSL feature is not enabled /// SSL feature is not enabled
#[fail(display="SSL is not supported")] #[fail(display = "SSL is not supported")]
SslIsNotSupported, SslIsNotSupported,
/// SSL error /// SSL error
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
SslError(#[cause] OpensslError), SslError(#[cause] OpensslError),
/// SSL error /// SSL error
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
SslError(#[cause] TlsError), SslError(#[cause] TlsError),
/// Connection error /// Connection error
@ -152,7 +154,7 @@ impl From<ConnectorError> for ClientConnectorError {
fn from(err: ConnectorError) -> ClientConnectorError { fn from(err: ConnectorError) -> ClientConnectorError {
match err { match err {
ConnectorError::Timeout => ClientConnectorError::Timeout, ConnectorError::Timeout => ClientConnectorError::Timeout,
_ => ClientConnectorError::Connector(err) _ => ClientConnectorError::Connector(err),
} }
} }
} }
@ -166,9 +168,9 @@ struct Waiter {
/// `ClientConnector` type is responsible for transport layer of a /// `ClientConnector` type is responsible for transport layer of a
/// client connection. /// client connection.
pub struct ClientConnector { pub struct ClientConnector {
#[cfg(all(feature="alpn"))] #[cfg(all(feature = "alpn"))]
connector: SslConnector, connector: SslConnector,
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
connector: TlsConnector, connector: TlsConnector,
stats: ClientConnectorStats, stats: ClientConnectorStats,
@ -207,12 +209,12 @@ impl Default for ClientConnector {
fn default() -> ClientConnector { fn default() -> ClientConnector {
let _modified = Rc::new(Cell::new(false)); let _modified = Rc::new(Cell::new(false));
#[cfg(all(feature="alpn"))] #[cfg(all(feature = "alpn"))]
{ {
let builder = SslConnector::builder(SslMethod::tls()).unwrap(); let builder = SslConnector::builder(SslMethod::tls()).unwrap();
ClientConnector::with_connector(builder.build()) ClientConnector::with_connector(builder.build())
} }
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
{ {
let builder = TlsConnector::builder().unwrap(); let builder = TlsConnector::builder().unwrap();
ClientConnector { ClientConnector {
@ -235,29 +237,29 @@ impl Default for ClientConnector {
} }
} }
#[cfg(not(any(feature="alpn", feature="tls")))] #[cfg(not(any(feature = "alpn", feature = "tls")))]
ClientConnector {stats: ClientConnectorStats::default(), ClientConnector {
subscriber: None, stats: ClientConnectorStats::default(),
pool: Rc::new(Pool::new(Rc::clone(&_modified))), subscriber: None,
pool_modified: _modified, pool: Rc::new(Pool::new(Rc::clone(&_modified))),
conn_lifetime: Duration::from_secs(15), pool_modified: _modified,
conn_keep_alive: Duration::from_secs(75), conn_lifetime: Duration::from_secs(15),
limit: 100, conn_keep_alive: Duration::from_secs(75),
limit_per_host: 0, limit: 100,
acquired: 0, limit_per_host: 0,
acquired_per_host: HashMap::new(), acquired: 0,
available: HashMap::new(), acquired_per_host: HashMap::new(),
to_close: Vec::new(), available: HashMap::new(),
waiters: HashMap::new(), to_close: Vec::new(),
wait_timeout: None, waiters: HashMap::new(),
paused: None, wait_timeout: None,
paused: None,
} }
} }
} }
impl ClientConnector { impl ClientConnector {
#[cfg(feature = "alpn")]
#[cfg(feature="alpn")]
/// Create `ClientConnector` actor with custom `SslConnector` instance. /// Create `ClientConnector` actor with custom `SslConnector` instance.
/// ///
/// By default `ClientConnector` uses very a simple SSL configuration. /// By default `ClientConnector` uses very a simple SSL configuration.
@ -369,20 +371,19 @@ impl ClientConnector {
// check limits // check limits
if self.limit > 0 { if self.limit > 0 {
if self.acquired >= self.limit { if self.acquired >= self.limit {
return Acquire::NotAvailable return Acquire::NotAvailable;
} }
if self.limit_per_host > 0 { if self.limit_per_host > 0 {
if let Some(per_host) = self.acquired_per_host.get(key) { if let Some(per_host) = self.acquired_per_host.get(key) {
if self.limit_per_host >= *per_host { if self.limit_per_host >= *per_host {
return Acquire::NotAvailable return Acquire::NotAvailable;
} }
} }
} }
} } else if self.limit_per_host > 0 {
else if self.limit_per_host > 0 {
if let Some(per_host) = self.acquired_per_host.get(key) { if let Some(per_host) = self.acquired_per_host.get(key) {
if self.limit_per_host >= *per_host { if self.limit_per_host >= *per_host {
return Acquire::NotAvailable return Acquire::NotAvailable;
} }
} }
} }
@ -408,11 +409,11 @@ impl ClientConnector {
Ok(n) if n > 0 => { Ok(n) if n > 0 => {
self.stats.closed += 1; self.stats.closed += 1;
self.to_close.push(conn); self.to_close.push(conn);
continue continue;
}, }
Ok(_) | Err(_) => continue, Ok(_) | Err(_) => continue,
} }
return Acquire::Acquired(conn) return Acquire::Acquired(conn);
} }
} }
} }
@ -421,25 +422,25 @@ impl ClientConnector {
fn reserve(&mut self, key: &Key) { fn reserve(&mut self, key: &Key) {
self.acquired += 1; self.acquired += 1;
let per_host = let per_host = if let Some(per_host) = self.acquired_per_host.get(key) {
if let Some(per_host) = self.acquired_per_host.get(key) { *per_host
*per_host } else {
} else { 0
0 };
}; self.acquired_per_host
self.acquired_per_host.insert(key.clone(), per_host + 1); .insert(key.clone(), per_host + 1);
} }
fn release_key(&mut self, key: &Key) { fn release_key(&mut self, key: &Key) {
self.acquired -= 1; self.acquired -= 1;
let per_host = let per_host = if let Some(per_host) = self.acquired_per_host.get(key) {
if let Some(per_host) = self.acquired_per_host.get(key) { *per_host
*per_host } else {
} else { return;
return };
};
if per_host > 1 { if per_host > 1 {
self.acquired_per_host.insert(key.clone(), per_host - 1); self.acquired_per_host
.insert(key.clone(), per_host - 1);
} else { } else {
self.acquired_per_host.remove(key); self.acquired_per_host.remove(key);
} }
@ -472,7 +473,8 @@ impl ClientConnector {
// check connection lifetime and the return to available pool // check connection lifetime and the return to available pool
if (now - conn.ts) < self.conn_lifetime { if (now - conn.ts) < self.conn_lifetime {
self.available.entry(conn.key.clone()) self.available
.entry(conn.key.clone())
.or_insert_with(VecDeque::new) .or_insert_with(VecDeque::new)
.push_back(Conn(Instant::now(), conn)); .push_back(Conn(Instant::now(), conn));
} }
@ -490,7 +492,7 @@ impl ClientConnector {
self.to_close.push(conn); self.to_close.push(conn);
self.stats.closed += 1; self.stats.closed += 1;
} else { } else {
break break;
} }
} }
} }
@ -503,7 +505,7 @@ impl ClientConnector {
Ok(Async::NotReady) => idx += 1, Ok(Async::NotReady) => idx += 1,
_ => { _ => {
self.to_close.swap_remove(idx); self.to_close.swap_remove(idx);
}, }
} }
} }
} }
@ -514,7 +516,9 @@ impl ClientConnector {
fn collect_periodic(&mut self, ctx: &mut Context<Self>) { fn collect_periodic(&mut self, ctx: &mut Context<Self>) {
self.collect(true); self.collect(true);
// re-schedule next collect period // re-schedule next collect period
ctx.run_later(Duration::from_secs(1), |act, ctx| act.collect_periodic(ctx)); ctx.run_later(Duration::from_secs(1), |act, ctx| {
act.collect_periodic(ctx)
});
// send stats // send stats
let stats = mem::replace(&mut self.stats, ClientConnectorStats::default()); let stats = mem::replace(&mut self.stats, ClientConnectorStats::default());
@ -555,27 +559,34 @@ impl ClientConnector {
fn install_wait_timeout(&mut self, time: Instant) { fn install_wait_timeout(&mut self, time: Instant) {
if let Some(ref mut wait) = self.wait_timeout { if let Some(ref mut wait) = self.wait_timeout {
if wait.0 < time { if wait.0 < time {
return return;
} }
} }
let mut timeout = Timeout::new(time-Instant::now(), Arbiter::handle()).unwrap(); let mut timeout =
Timeout::new(time - Instant::now(), Arbiter::handle()).unwrap();
let _ = timeout.poll(); let _ = timeout.poll();
self.wait_timeout = Some((time, timeout)); self.wait_timeout = Some((time, timeout));
} }
fn wait_for(&mut self, key: Key, fn wait_for(
wait: Duration, conn_timeout: Duration) &mut self, key: Key, wait: Duration, conn_timeout: Duration
-> oneshot::Receiver<Result<Connection, ClientConnectorError>> ) -> oneshot::Receiver<Result<Connection, ClientConnectorError>> {
{
// connection is not available, wait // connection is not available, wait
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let wait = Instant::now() + wait; let wait = Instant::now() + wait;
self.install_wait_timeout(wait); self.install_wait_timeout(wait);
let waiter = Waiter{ tx, wait, conn_timeout }; let waiter = Waiter {
self.waiters.entry(key).or_insert_with(VecDeque::new).push_back(waiter); tx,
wait,
conn_timeout,
};
self.waiters
.entry(key)
.or_insert_with(VecDeque::new)
.push_back(waiter);
rx rx
} }
} }
@ -617,21 +628,23 @@ impl Handler<Connect> for ClientConnector {
// host name is required // host name is required
if uri.host().is_none() { if uri.host().is_none() {
return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)) return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl));
} }
// supported protocols // supported protocols
let proto = match uri.scheme_part() { let proto = match uri.scheme_part() {
Some(scheme) => match Protocol::from(scheme.as_str()) { Some(scheme) => match Protocol::from(scheme.as_str()) {
Some(proto) => proto, Some(proto) => proto,
None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)), None => {
return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl))
}
}, },
None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)), None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)),
}; };
// check ssl availability // check ssl availability
if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS { if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS {
return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported)) return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported));
} }
// check if pool has task reference // check if pool has task reference
@ -641,7 +654,11 @@ impl Handler<Connect> for ClientConnector {
let host = uri.host().unwrap().to_owned(); let host = uri.host().unwrap().to_owned();
let port = uri.port().unwrap_or_else(|| proto.port()); let port = uri.port().unwrap_or_else(|| proto.port());
let key = Key {host, port, ssl: proto.is_secure()}; let key = Key {
host,
port,
ssl: proto.is_secure(),
};
// check pause state // check pause state
if self.paused.is_some() { if self.paused.is_some() {
@ -653,7 +670,8 @@ impl Handler<Connect> for ClientConnector {
.and_then(|res, _, _| match res { .and_then(|res, _, _| match res {
Ok(conn) => fut::ok(conn), Ok(conn) => fut::ok(conn),
Err(err) => fut::err(err), Err(err) => fut::err(err),
})); }),
);
} }
// acquire connection // acquire connection
@ -663,8 +681,8 @@ impl Handler<Connect> for ClientConnector {
// use existing connection // use existing connection
conn.pool = Some(AcquiredConn(key, Some(Rc::clone(&self.pool)))); conn.pool = Some(AcquiredConn(key, Some(Rc::clone(&self.pool))));
self.stats.reused += 1; self.stats.reused += 1;
return ActorResponse::async(fut::ok(conn)) return ActorResponse::async(fut::ok(conn));
}, }
Acquire::NotAvailable => { Acquire::NotAvailable => {
// connection is not available, wait // connection is not available, wait
let rx = self.wait_for(key, wait_timeout, conn_timeout); let rx = self.wait_for(key, wait_timeout, conn_timeout);
@ -675,106 +693,131 @@ impl Handler<Connect> for ClientConnector {
.and_then(|res, _, _| match res { .and_then(|res, _, _| match res {
Ok(conn) => fut::ok(conn), Ok(conn) => fut::ok(conn),
Err(err) => fut::err(err), Err(err) => fut::err(err),
})); }),
);
} }
Acquire::Available => { Acquire::Available => Some(Rc::clone(&self.pool)),
Some(Rc::clone(&self.pool))
},
} }
} else { } else {
None None
}; };
let conn = AcquiredConn(key, pool); let conn = AcquiredConn(key, pool);
{ {
ActorResponse::async( ActorResponse::async(
Connector::from_registry() Connector::from_registry()
.send(ResolveConnect::host_and_port(&conn.0.host, port) .send(
.timeout(conn_timeout)) ResolveConnect::host_and_port(&conn.0.host, port)
.into_actor(self) .timeout(conn_timeout),
.map_err(|_, _, _| ClientConnectorError::Disconnected) )
.and_then(move |res, act, _| { .into_actor(self)
#[cfg(feature="alpn")] .map_err(|_, _, _| ClientConnectorError::Disconnected)
match res { .and_then(move |res, act, _| {
Err(err) => { #[cfg(feature = "alpn")]
act.stats.opened += 1; match res {
fut::Either::B(fut::err(err.into())) Err(err) => {
}, act.stats.opened += 1;
Ok(stream) => { fut::Either::B(fut::err(err.into()))
act.stats.opened += 1; }
if proto.is_secure() { Ok(stream) => {
fut::Either::A( act.stats.opened += 1;
act.connector.connect_async(&conn.0.host, stream) if proto.is_secure() {
.map_err(ClientConnectorError::SslError) fut::Either::A(
.map(|stream| Connection::new( act.connector
conn.0.clone(), Some(conn), Box::new(stream))) .connect_async(&conn.0.host, stream)
.into_actor(act)) .map_err(ClientConnectorError::SslError)
} else { .map(|stream| {
fut::Either::B(fut::ok( Connection::new(
Connection::new( conn.0.clone(),
conn.0.clone(), Some(conn), Box::new(stream)))) Some(conn),
Box::new(stream),
)
})
.into_actor(act),
)
} else {
fut::Either::B(fut::ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)))
}
}
} }
}
}
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg(all(feature = "tls", not(feature = "alpn")))]
match res { match res {
Err(err) => { Err(err) => {
act.stats.opened += 1; act.stats.opened += 1;
fut::Either::B(fut::err(err.into())) fut::Either::B(fut::err(err.into()))
}, }
Ok(stream) => { Ok(stream) => {
act.stats.opened += 1; act.stats.opened += 1;
if proto.is_secure() { if proto.is_secure() {
fut::Either::A( fut::Either::A(
act.connector.connect_async(&conn.0.host, stream) act.connector
.map_err(ClientConnectorError::SslError) .connect_async(&conn.0.host, stream)
.map(|stream| Connection::new( .map_err(ClientConnectorError::SslError)
conn.0.clone(), Some(conn), Box::new(stream))) .map(|stream| {
.into_actor(act)) Connection::new(
} else { conn.0.clone(),
fut::Either::B(fut::ok( Some(conn),
Connection::new( Box::new(stream),
conn.0.clone(), Some(conn), Box::new(stream)))) )
})
.into_actor(act),
)
} else {
fut::Either::B(fut::ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)))
}
}
} }
}
}
#[cfg(not(any(feature="alpn", feature="tls")))] #[cfg(not(any(feature = "alpn", feature = "tls")))]
match res { match res {
Err(err) => { Err(err) => {
act.stats.opened += 1; act.stats.opened += 1;
fut::err(err.into()) fut::err(err.into())
}, }
Ok(stream) => { Ok(stream) => {
act.stats.opened += 1; act.stats.opened += 1;
if proto.is_secure() { if proto.is_secure() {
fut::err(ClientConnectorError::SslIsNotSupported) fut::err(ClientConnectorError::SslIsNotSupported)
} else { } else {
fut::ok(Connection::new( fut::ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream))) conn.0.clone(),
Some(conn),
Box::new(stream),
))
}
}
} }
} }),
} )
})) }
}
} }
} }
struct Maintenance; struct Maintenance;
impl fut::ActorFuture for Maintenance impl fut::ActorFuture for Maintenance {
{
type Item = (); type Item = ();
type Error = (); type Error = ();
type Actor = ClientConnector; type Actor = ClientConnector;
fn poll(&mut self, act: &mut ClientConnector, ctx: &mut Context<ClientConnector>) fn poll(
-> Poll<Self::Item, Self::Error> &mut self, act: &mut ClientConnector, ctx: &mut Context<ClientConnector>
{ ) -> Poll<Self::Item, Self::Error> {
// check pause duration // check pause duration
let done = if let Some(Some(ref pause)) = act.paused { let done = if let Some(Some(ref pause)) = act.paused {
pause.0 <= Instant::now() } else { false }; pause.0 <= Instant::now()
} else {
false
};
if done { if done {
act.paused.take(); act.paused.take();
} }
@ -788,128 +831,151 @@ impl fut::ActorFuture for Maintenance
act.collect_waiters(); act.collect_waiters();
// check waiters // check waiters
let tmp: &mut ClientConnector = unsafe{mem::transmute(act as &mut _)}; let tmp: &mut ClientConnector = unsafe { mem::transmute(act as &mut _) };
for (key, waiters) in &mut tmp.waiters { for (key, waiters) in &mut tmp.waiters {
while let Some(waiter) = waiters.pop_front() { while let Some(waiter) = waiters.pop_front() {
if waiter.tx.is_canceled() { continue } if waiter.tx.is_canceled() {
continue;
}
match act.acquire(key) { match act.acquire(key) {
Acquire::Acquired(mut conn) => { Acquire::Acquired(mut conn) => {
// use existing connection // use existing connection
act.stats.reused += 1; act.stats.reused += 1;
conn.pool = Some( conn.pool =
AcquiredConn(key.clone(), Some(Rc::clone(&act.pool)))); Some(AcquiredConn(key.clone(), Some(Rc::clone(&act.pool))));
let _ = waiter.tx.send(Ok(conn)); let _ = waiter.tx.send(Ok(conn));
}, }
Acquire::NotAvailable => { Acquire::NotAvailable => {
waiters.push_front(waiter); waiters.push_front(waiter);
break break;
} }
Acquire::Available => Acquire::Available => {
{ let conn = AcquiredConn(key.clone(), Some(Rc::clone(&act.pool)));
let conn = AcquiredConn(key.clone(), Some(Rc::clone(&act.pool)));
fut::WrapFuture::<ClientConnector>::actfuture( fut::WrapFuture::<ClientConnector>::actfuture(
Connector::from_registry() Connector::from_registry().send(
.send(ResolveConnect::host_and_port(&conn.0.host, conn.0.port) ResolveConnect::host_and_port(&conn.0.host, conn.0.port)
.timeout(waiter.conn_timeout))) .timeout(waiter.conn_timeout),
.map_err(|_, _, _| ()) ),
.and_then(move |res, act, _| { ).map_err(|_, _, _| ())
#[cfg(feature="alpn")] .and_then(move |res, act, _| {
match res { #[cfg_attr(rustfmt, rustfmt_skip)]
Err(err) => { #[cfg(feature = "alpn")]
act.stats.errors += 1; match res {
let _ = waiter.tx.send(Err(err.into())); Err(err) => {
fut::Either::B(fut::err(())) act.stats.errors += 1;
}, let _ = waiter.tx.send(Err(err.into()));
Ok(stream) => { fut::Either::B(fut::err(()))
act.stats.opened += 1; }
if conn.0.ssl { Ok(stream) => {
fut::Either::A( act.stats.opened += 1;
act.connector.connect_async(&key.host, stream) if conn.0.ssl {
.then(move |res| { fut::Either::A(
match res { act.connector
Err(e) => { .connect_async(&key.host, stream)
let _ = waiter.tx.send(Err( .then(move |res| {
ClientConnectorError::SslError(e))); match res {
}, Err(e) => {
Ok(stream) => { let _ = waiter.tx.send(
let _ = waiter.tx.send(Ok( Err(ClientConnectorError::SslError(e)));
Connection::new( }
conn.0.clone(), Ok(stream) => {
Some(conn), Box::new(stream)))); let _ = waiter.tx.send(
} Ok(Connection::new(
} conn.0.clone(),
Ok(()) Some(conn),
}) Box::new(stream),
.actfuture()) )),
} else { );
let _ = waiter.tx.send(Ok(Connection::new( }
conn.0.clone(), Some(conn), Box::new(stream)))); }
fut::Either::B(fut::ok(())) Ok(())
} })
} .actfuture(),
} )
} else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)));
fut::Either::B(fut::ok(()))
}
}
}
#[cfg(all(feature="tls", not(feature="alpn")))] #[cfg_attr(rustfmt, rustfmt_skip)]
match res { #[cfg(all(feature = "tls", not(feature = "alpn")))]
Err(err) => { match res {
act.stats.errors += 1; Err(err) => {
let _ = waiter.tx.send(Err(err.into())); act.stats.errors += 1;
fut::Either::B(fut::err(())) let _ = waiter.tx.send(Err(err.into()));
}, fut::Either::B(fut::err(()))
Ok(stream) => { }
act.stats.opened += 1; Ok(stream) => {
if conn.0.ssl { act.stats.opened += 1;
fut::Either::A( if conn.0.ssl {
act.connector.connect_async(&conn.0.host, stream) fut::Either::A(
.then(|res| { act.connector
match res { .connect_async(&conn.0.host, stream)
Err(e) => { .then(|res| {
let _ = waiter.tx.send(Err( match res {
ClientConnectorError::SslError(e))); Err(e) => {
}, let _ = waiter.tx.send(Err(
Ok(stream) => { ClientConnectorError::SslError(e),
let _ = waiter.tx.send(Ok( ));
Connection::new( }
conn.0.clone(), Some(conn), Ok(stream) => {
Box::new(stream)))); let _ = waiter.tx.send(
} Ok(Connection::new(
} conn.0.clone(), Some(conn),
Ok(()) Box::new(stream),
}) )),
.into_actor(act)) );
} else { }
let _ = waiter.tx.send(Ok(Connection::new( }
conn.0.clone(), Some(conn), Box::new(stream)))); Ok(())
fut::Either::B(fut::ok(())) })
} .into_actor(act),
} )
} } else {
let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(),
Some(conn),
Box::new(stream),
)));
fut::Either::B(fut::ok(()))
}
}
}
#[cfg(not(any(feature="alpn", feature="tls")))] #[cfg_attr(rustfmt, rustfmt_skip)]
match res { #[cfg(not(any(feature = "alpn", feature = "tls")))]
Err(err) => { match res {
act.stats.errors += 1; Err(err) => {
let _ = waiter.tx.send(Err(err.into())); act.stats.errors += 1;
fut::err(()) let _ = waiter.tx.send(Err(err.into()));
}, fut::err(())
Ok(stream) => { }
act.stats.opened += 1; Ok(stream) => {
if conn.0.ssl { act.stats.opened += 1;
let _ = waiter.tx.send( if conn.0.ssl {
Err(ClientConnectorError::SslIsNotSupported)); let _ = waiter.tx.send(Err(ClientConnectorError::SslIsNotSupported));
} else { } else {
let _ = waiter.tx.send(Ok(Connection::new( let _ = waiter.tx.send(Ok(Connection::new(
conn.0.clone(), Some(conn), Box::new(stream)))); conn.0.clone(),
}; Some(conn),
fut::ok(()) Box::new(stream),
}, )));
} };
}) fut::ok(())
.spawn(ctx); }
} }
})
.spawn(ctx);
}
} }
} }
} }
@ -954,7 +1020,7 @@ impl Protocol {
fn port(&self) -> u16 { fn port(&self) -> u16 {
match *self { match *self {
Protocol::Http | Protocol::Ws => 80, Protocol::Http | Protocol::Ws => 80,
Protocol::Https | Protocol::Wss => 443 Protocol::Https | Protocol::Wss => 443,
} }
} }
} }
@ -968,7 +1034,11 @@ struct Key {
impl Key { impl Key {
fn empty() -> Key { fn empty() -> Key {
Key{host: String::new(), port: 0, ssl: false} Key {
host: String::new(),
port: 0,
ssl: false,
}
} }
} }
@ -1035,7 +1105,10 @@ impl Pool {
if self.to_close.borrow().is_empty() { if self.to_close.borrow().is_empty() {
None None
} else { } else {
Some(mem::replace(&mut *self.to_close.borrow_mut(), Vec::new())) Some(mem::replace(
&mut *self.to_close.borrow_mut(),
Vec::new(),
))
} }
} }
@ -1043,7 +1116,10 @@ impl Pool {
if self.to_release.borrow().is_empty() { if self.to_release.borrow().is_empty() {
None None
} else { } else {
Some(mem::replace(&mut *self.to_release.borrow_mut(), Vec::new())) Some(mem::replace(
&mut *self.to_release.borrow_mut(),
Vec::new(),
))
} }
} }
@ -1072,7 +1148,6 @@ impl Pool {
} }
} }
pub struct Connection { pub struct Connection {
key: Key, key: Key,
stream: Box<IoStream>, stream: Box<IoStream>,
@ -1088,7 +1163,12 @@ impl fmt::Debug for Connection {
impl Connection { impl Connection {
fn new(key: Key, pool: Option<AcquiredConn>, stream: Box<IoStream>) -> Self { fn new(key: Key, pool: Option<AcquiredConn>, stream: Box<IoStream>) -> Self {
Connection {key, stream, pool, ts: Instant::now()} Connection {
key,
stream,
pool,
ts: Instant::now(),
}
} }
pub fn stream(&mut self) -> &mut IoStream { pub fn stream(&mut self) -> &mut IoStream {

View File

@ -28,33 +28,30 @@
//! ``` //! ```
mod connector; mod connector;
mod parser; mod parser;
mod pipeline;
mod request; mod request;
mod response; mod response;
mod pipeline;
mod writer; mod writer;
pub use self::connector::{ClientConnector, ClientConnectorError, ClientConnectorStats,
Connect, Connection, Pause, Resume};
pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError};
pub use self::pipeline::{SendRequest, SendRequestError}; pub use self::pipeline::{SendRequest, SendRequestError};
pub use self::request::{ClientRequest, ClientRequestBuilder}; pub use self::request::{ClientRequest, ClientRequestBuilder};
pub use self::response::ClientResponse; pub use self::response::ClientResponse;
pub use self::connector::{
Connect, Pause, Resume,
Connection, ClientConnector, ClientConnectorError, ClientConnectorStats};
pub(crate) use self::writer::HttpClientWriter; pub(crate) use self::writer::HttpClientWriter;
pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError};
use error::ResponseError; use error::ResponseError;
use http::Method; use http::Method;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
/// Convert `SendRequestError` to a `HttpResponse` /// Convert `SendRequestError` to a `HttpResponse`
impl ResponseError for SendRequestError { impl ResponseError for SendRequestError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { match *self {
SendRequestError::Connector(_) => HttpResponse::BadGateway(), SendRequestError::Connector(_) => HttpResponse::BadGateway(),
_ => HttpResponse::InternalServerError(), _ => HttpResponse::InternalServerError(),
} }.into()
.into()
} }
} }

View File

@ -1,14 +1,14 @@
use std::mem;
use httparse;
use http::{Version, HttpTryFrom, HeaderMap, StatusCode};
use http::header::{self, HeaderName, HeaderValue};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Poll, Async}; use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue};
use http::{HeaderMap, HttpTryFrom, StatusCode, Version};
use httparse;
use std::mem;
use error::{ParseError, PayloadError}; use error::{ParseError, PayloadError};
use server::h1::{chunked, Decoder};
use server::{utils, IoStream}; use server::{utils, IoStream};
use server::h1::{Decoder, chunked};
use super::ClientResponse; use super::ClientResponse;
use super::response::ClientMessage; use super::response::ClientMessage;
@ -24,28 +24,26 @@ pub struct HttpResponseParser {
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
pub enum HttpResponseParserError { pub enum HttpResponseParserError {
/// Server disconnected /// Server disconnected
#[fail(display="Server disconnected")] #[fail(display = "Server disconnected")]
Disconnect, Disconnect,
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Error(#[cause] ParseError), Error(#[cause] ParseError),
} }
impl HttpResponseParser { impl HttpResponseParser {
pub fn parse<T>(
pub fn parse<T>(&mut self, io: &mut T, buf: &mut BytesMut) &mut self, io: &mut T, buf: &mut BytesMut
-> Poll<ClientResponse, HttpResponseParserError> ) -> Poll<ClientResponse, HttpResponseParserError>
where T: IoStream where
T: IoStream,
{ {
// if buf is empty parse_message will always return NotReady, let's avoid that // if buf is empty parse_message will always return NotReady, let's avoid that
if buf.is_empty() { if buf.is_empty() {
match utils::read_from_io(io, buf) { match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) => Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect),
return Err(HttpResponseParserError::Disconnect),
Ok(Async::Ready(_)) => (), Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => Ok(Async::NotReady) => return Ok(Async::NotReady),
return Ok(Async::NotReady), Err(err) => return Err(HttpResponseParserError::Error(err.into())),
Err(err) =>
return Err(HttpResponseParserError::Error(err.into()))
} }
} }
@ -56,27 +54,31 @@ impl HttpResponseParser {
Async::Ready((msg, decoder)) => { Async::Ready((msg, decoder)) => {
self.decoder = decoder; self.decoder = decoder;
return Ok(Async::Ready(msg)); return Ok(Async::Ready(msg));
}, }
Async::NotReady => { Async::NotReady => {
if buf.capacity() >= MAX_BUFFER_SIZE { if buf.capacity() >= MAX_BUFFER_SIZE {
return Err(HttpResponseParserError::Error(ParseError::TooLarge)); return Err(HttpResponseParserError::Error(ParseError::TooLarge));
} }
match utils::read_from_io(io, buf) { match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) => Ok(Async::Ready(0)) => {
return Err(HttpResponseParserError::Disconnect), return Err(HttpResponseParserError::Disconnect)
}
Ok(Async::Ready(_)) => (), Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => Err(err) => {
return Err(HttpResponseParserError::Error(err.into())), return Err(HttpResponseParserError::Error(err.into()))
}
} }
}, }
} }
} }
} }
pub fn parse_payload<T>(&mut self, io: &mut T, buf: &mut BytesMut) pub fn parse_payload<T>(
-> Poll<Option<Bytes>, PayloadError> &mut self, io: &mut T, buf: &mut BytesMut
where T: IoStream ) -> Poll<Option<Bytes>, PayloadError>
where
T: IoStream,
{ {
if self.decoder.is_some() { if self.decoder.is_some() {
loop { loop {
@ -89,18 +91,17 @@ impl HttpResponseParser {
}; };
match self.decoder.as_mut().unwrap().decode(buf) { match self.decoder.as_mut().unwrap().decode(buf) {
Ok(Async::Ready(Some(b))) => Ok(Async::Ready(Some(b))) => return Ok(Async::Ready(Some(b))),
return Ok(Async::Ready(Some(b))),
Ok(Async::Ready(None)) => { Ok(Async::Ready(None)) => {
self.decoder.take(); self.decoder.take();
return Ok(Async::Ready(None)) return Ok(Async::Ready(None));
} }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
if not_ready { if not_ready {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
if stream_finished { if stream_finished {
return Err(PayloadError::Incomplete) return Err(PayloadError::Incomplete);
} }
} }
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
@ -111,16 +112,19 @@ impl HttpResponseParser {
} }
} }
fn parse_message(buf: &mut BytesMut) fn parse_message(
-> Poll<(ClientResponse, Option<Decoder>), ParseError> buf: &mut BytesMut
{ ) -> Poll<(ClientResponse, Option<Decoder>), ParseError> {
// Parse http message // Parse http message
let bytes_ptr = buf.as_ref().as_ptr() as usize; let bytes_ptr = buf.as_ref().as_ptr() as usize;
let mut headers: [httparse::Header; MAX_HEADERS] = let mut headers: [httparse::Header; MAX_HEADERS] =
unsafe{mem::uninitialized()}; unsafe { mem::uninitialized() };
let (len, version, status, headers_len) = { let (len, version, status, headers_len) = {
let b = unsafe{ let b: &[u8] = buf; mem::transmute(b) }; let b = unsafe {
let b: &[u8] = buf;
mem::transmute(b)
};
let mut resp = httparse::Response::new(&mut headers); let mut resp = httparse::Response::new(&mut headers);
match resp.parse(b)? { match resp.parse(b)? {
httparse::Status::Complete(len) => { httparse::Status::Complete(len) => {
@ -147,10 +151,11 @@ impl HttpResponseParser {
let v_start = header.value.as_ptr() as usize - bytes_ptr; let v_start = header.value.as_ptr() as usize - bytes_ptr;
let v_end = v_start + header.value.len(); let v_end = v_start + header.value.len();
let value = unsafe { let value = unsafe {
HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end)) }; HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end))
};
hdrs.append(name, value); hdrs.append(name, value);
} else { } else {
return Err(ParseError::Header) return Err(ParseError::Header);
} }
} }
@ -163,11 +168,11 @@ impl HttpResponseParser {
Some(Decoder::length(len)) Some(Decoder::length(len))
} else { } else {
debug!("illegal Content-Length: {:?}", len); debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header) return Err(ParseError::Header);
} }
} else { } else {
debug!("illegal Content-Length: {:?}", len); debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header) return Err(ParseError::Header);
} }
} else if chunked(&hdrs)? { } else if chunked(&hdrs)? {
// Chunked encoding // Chunked encoding
@ -177,15 +182,25 @@ impl HttpResponseParser {
}; };
if let Some(decoder) = decoder { if let Some(decoder) = decoder {
Ok(Async::Ready( Ok(Async::Ready((
(ClientResponse::new( ClientResponse::new(ClientMessage {
ClientMessage{status, version, status,
headers: hdrs, cookies: None}), Some(decoder)))) version,
headers: hdrs,
cookies: None,
}),
Some(decoder),
)))
} else { } else {
Ok(Async::Ready( Ok(Async::Ready((
(ClientResponse::new( ClientResponse::new(ClientMessage {
ClientMessage{status, version, status,
headers: hdrs, cookies: None}), None))) version,
headers: hdrs,
cookies: None,
}),
None,
)))
} }
} }
} }

View File

@ -1,26 +1,26 @@
use std::{io, mem};
use std::time::Duration;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use http::header::CONTENT_ENCODING;
use futures::{Async, Future, Poll};
use futures::unsync::oneshot; use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use http::header::CONTENT_ENCODING;
use std::time::Duration;
use std::{io, mem};
use tokio_core::reactor::Timeout; use tokio_core::reactor::Timeout;
use actix::prelude::*; use actix::prelude::*;
use error::Error; use super::HttpClientWriter;
use super::{ClientConnector, ClientConnectorError, Connect, Connection};
use super::{ClientRequest, ClientResponse};
use super::{HttpResponseParser, HttpResponseParserError};
use body::{Body, BodyStream}; use body::{Body, BodyStream};
use context::{Frame, ActorHttpContext}; use context::{ActorHttpContext, Frame};
use error::Error;
use error::PayloadError;
use header::ContentEncoding; use header::ContentEncoding;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use error::PayloadError;
use server::WriterState; use server::WriterState;
use server::shared::SharedBytes;
use server::encoding::PayloadStream; use server::encoding::PayloadStream;
use super::{ClientRequest, ClientResponse}; use server::shared::SharedBytes;
use super::{Connect, Connection, ClientConnector, ClientConnectorError};
use super::HttpClientWriter;
use super::{HttpResponseParser, HttpResponseParserError};
/// A set of errors that can occur during request sending and response reading /// A set of errors that can occur during request sending and response reading
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
@ -29,13 +29,13 @@ pub enum SendRequestError {
#[fail(display = "Timeout while waiting for response")] #[fail(display = "Timeout while waiting for response")]
Timeout, Timeout,
/// Failed to connect to host /// Failed to connect to host
#[fail(display="Failed to connect to host: {}", _0)] #[fail(display = "Failed to connect to host: {}", _0)]
Connector(#[cause] ClientConnectorError), Connector(#[cause] ClientConnectorError),
/// Error parsing response /// Error parsing response
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
ParseError(#[cause] HttpResponseParserError), ParseError(#[cause] HttpResponseParserError),
/// Error reading response payload /// Error reading response payload
#[fail(display="Error reading response payload: {}", _0)] #[fail(display = "Error reading response payload: {}", _0)]
Io(#[cause] io::Error), Io(#[cause] io::Error),
} }
@ -79,25 +79,27 @@ impl SendRequest {
SendRequest::with_connector(req, ClientConnector::from_registry()) SendRequest::with_connector(req, ClientConnector::from_registry())
} }
pub(crate) fn with_connector(req: ClientRequest, conn: Addr<Unsync, ClientConnector>) pub(crate) fn with_connector(
-> SendRequest req: ClientRequest, conn: Addr<Unsync, ClientConnector>
{ ) -> SendRequest {
SendRequest{req, conn, SendRequest {
state: State::New, req,
timeout: None, conn,
wait_timeout: Duration::from_secs(5), state: State::New,
conn_timeout: Duration::from_secs(1), timeout: None,
wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1),
} }
} }
pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest {
{ SendRequest {
SendRequest{req, req,
state: State::Connection(conn), state: State::Connection(conn),
conn: ClientConnector::from_registry(), conn: ClientConnector::from_registry(),
timeout: None, timeout: None,
wait_timeout: Duration::from_secs(5), wait_timeout: Duration::from_secs(5),
conn_timeout: Duration::from_secs(1), conn_timeout: Duration::from_secs(1),
} }
} }
@ -139,25 +141,27 @@ impl Future for SendRequest {
let state = mem::replace(&mut self.state, State::None); let state = mem::replace(&mut self.state, State::None);
match state { match state {
State::New => State::New => {
self.state = State::Connect(self.conn.send(Connect { self.state = State::Connect(self.conn.send(Connect {
uri: self.req.uri().clone(), uri: self.req.uri().clone(),
wait_timeout: self.wait_timeout, wait_timeout: self.wait_timeout,
conn_timeout: self.conn_timeout, conn_timeout: self.conn_timeout,
})), }))
}
State::Connect(mut conn) => match conn.poll() { State::Connect(mut conn) => match conn.poll() {
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.state = State::Connect(conn); self.state = State::Connect(conn);
return Ok(Async::NotReady); return Ok(Async::NotReady);
}, }
Ok(Async::Ready(result)) => match result { Ok(Async::Ready(result)) => match result {
Ok(stream) => { Ok(stream) => self.state = State::Connection(stream),
self.state = State::Connection(stream)
},
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}, },
Err(_) => return Err(SendRequestError::Connector( Err(_) => {
ClientConnectorError::Disconnected)) return Err(SendRequestError::Connector(
ClientConnectorError::Disconnected,
))
}
}, },
State::Connection(conn) => { State::Connection(conn) => {
let mut writer = HttpClientWriter::new(SharedBytes::default()); let mut writer = HttpClientWriter::new(SharedBytes::default());
@ -169,12 +173,13 @@ impl Future for SendRequest {
_ => IoBody::Done, _ => IoBody::Done,
}; };
let timeout = self.timeout.take().unwrap_or_else(|| let timeout = self.timeout.take().unwrap_or_else(|| {
Timeout::new( Timeout::new(Duration::from_secs(5), Arbiter::handle()).unwrap()
Duration::from_secs(5), Arbiter::handle()).unwrap()); });
let pl = Box::new(Pipeline { let pl = Box::new(Pipeline {
body, writer, body,
writer,
conn: Some(conn), conn: Some(conn),
parser: Some(HttpResponseParser::default()), parser: Some(HttpResponseParser::default()),
parser_buf: BytesMut::new(), parser_buf: BytesMut::new(),
@ -186,22 +191,22 @@ impl Future for SendRequest {
timeout: Some(timeout), timeout: Some(timeout),
}); });
self.state = State::Send(pl); self.state = State::Send(pl);
}, }
State::Send(mut pl) => { State::Send(mut pl) => {
pl.poll_write() pl.poll_write().map_err(|e| {
.map_err(|e| io::Error::new( io::Error::new(io::ErrorKind::Other, format!("{}", e).as_str())
io::ErrorKind::Other, format!("{}", e).as_str()))?; })?;
match pl.parse() { match pl.parse() {
Ok(Async::Ready(mut resp)) => { Ok(Async::Ready(mut resp)) => {
resp.set_pipeline(pl); resp.set_pipeline(pl);
return Ok(Async::Ready(resp)) return Ok(Async::Ready(resp));
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.state = State::Send(pl); self.state = State::Send(pl);
return Ok(Async::NotReady) return Ok(Async::NotReady);
}, }
Err(err) => return Err(SendRequestError::ParseError(err)) Err(err) => return Err(SendRequestError::ParseError(err)),
} }
} }
State::None => unreachable!(), State::None => unreachable!(),
@ -210,7 +215,6 @@ impl Future for SendRequest {
} }
} }
pub(crate) struct Pipeline { pub(crate) struct Pipeline {
body: IoBody, body: IoBody,
conn: Option<Connection>, conn: Option<Connection>,
@ -254,7 +258,6 @@ impl RunningState {
} }
impl Pipeline { impl Pipeline {
fn release_conn(&mut self) { fn release_conn(&mut self) {
if let Some(conn) = self.conn.take() { if let Some(conn) = self.conn.take() {
conn.release() conn.release()
@ -264,15 +267,22 @@ impl Pipeline {
#[inline] #[inline]
fn parse(&mut self) -> Poll<ClientResponse, HttpResponseParserError> { fn parse(&mut self) -> Poll<ClientResponse, HttpResponseParserError> {
if let Some(ref mut conn) = self.conn { if let Some(ref mut conn) = self.conn {
match self.parser.as_mut().unwrap().parse(conn, &mut self.parser_buf) { match self.parser
.as_mut()
.unwrap()
.parse(conn, &mut self.parser_buf)
{
Ok(Async::Ready(resp)) => { Ok(Async::Ready(resp)) => {
// check content-encoding // check content-encoding
if self.should_decompress { if self.should_decompress {
if let Some(enc) = resp.headers().get(CONTENT_ENCODING) { if let Some(enc) = resp.headers().get(CONTENT_ENCODING) {
if let Ok(enc) = enc.to_str() { if let Ok(enc) = enc.to_str() {
match ContentEncoding::from(enc) { match ContentEncoding::from(enc) {
ContentEncoding::Auto | ContentEncoding::Identity => (), ContentEncoding::Auto
enc => self.decompress = Some(PayloadStream::new(enc)), | ContentEncoding::Identity => (),
enc => {
self.decompress = Some(PayloadStream::new(enc))
}
} }
} }
} }
@ -290,9 +300,10 @@ impl Pipeline {
#[inline] #[inline]
pub fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> { pub fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> {
if self.conn.is_none() { if self.conn.is_none() {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None));
} }
let conn: &mut Connection = unsafe{ mem::transmute(self.conn.as_mut().unwrap())}; let conn: &mut Connection =
unsafe { mem::transmute(self.conn.as_mut().unwrap()) };
let mut need_run = false; let mut need_run = false;
@ -302,15 +313,18 @@ impl Pipeline {
{ {
Async::NotReady => need_run = true, Async::NotReady => need_run = true,
Async::Ready(_) => { Async::Ready(_) => {
let _ = self.poll_timeout() let _ = self.poll_timeout().map_err(|e| {
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?; io::Error::new(io::ErrorKind::Other, format!("{}", e))
})?;
} }
} }
// need read? // need read?
if self.parser.is_some() { if self.parser.is_some() {
loop { loop {
match self.parser.as_mut().unwrap() match self.parser
.as_mut()
.unwrap()
.parse_payload(conn, &mut self.parser_buf)? .parse_payload(conn, &mut self.parser_buf)?
{ {
Async::Ready(Some(b)) => { Async::Ready(Some(b)) => {
@ -318,17 +332,20 @@ impl Pipeline {
match decompress.feed_data(b) { match decompress.feed_data(b) {
Ok(Some(b)) => return Ok(Async::Ready(Some(b))), Ok(Some(b)) => return Ok(Async::Ready(Some(b))),
Ok(None) => return Ok(Async::NotReady), Ok(None) => return Ok(Async::NotReady),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Err(ref err)
continue, if err.kind() == io::ErrorKind::WouldBlock =>
{
continue
}
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
} }
} else { } else {
return Ok(Async::Ready(Some(b))) return Ok(Async::Ready(Some(b)));
} }
}, }
Async::Ready(None) => { Async::Ready(None) => {
let _ = self.parser.take(); let _ = self.parser.take();
break break;
} }
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
} }
@ -340,7 +357,7 @@ impl Pipeline {
let res = decompress.feed_eof(); let res = decompress.feed_eof();
if let Some(b) = res? { if let Some(b) = res? {
self.release_conn(); self.release_conn();
return Ok(Async::Ready(Some(b))) return Ok(Async::Ready(Some(b)));
} }
} }
@ -357,7 +374,7 @@ impl Pipeline {
match self.timeout.as_mut().unwrap().poll() { match self.timeout.as_mut().unwrap().poll() {
Ok(Async::Ready(())) => Err(SendRequestError::Timeout), Ok(Async::Ready(())) => Err(SendRequestError::Timeout),
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => unreachable!() Err(_) => unreachable!(),
} }
} else { } else {
Ok(Async::NotReady) Ok(Async::NotReady)
@ -367,29 +384,27 @@ impl Pipeline {
#[inline] #[inline]
fn poll_write(&mut self) -> Poll<(), Error> { fn poll_write(&mut self) -> Poll<(), Error> {
if self.write_state == RunningState::Done || self.conn.is_none() { if self.write_state == RunningState::Done || self.conn.is_none() {
return Ok(Async::Ready(())) return Ok(Async::Ready(()));
} }
let mut done = false; let mut done = false;
if self.drain.is_none() && self.write_state != RunningState::Paused { if self.drain.is_none() && self.write_state != RunningState::Paused {
'outter: loop { 'outter: loop {
let result = match mem::replace(&mut self.body, IoBody::Done) { let result = match mem::replace(&mut self.body, IoBody::Done) {
IoBody::Payload(mut body) => { IoBody::Payload(mut body) => match body.poll()? {
match body.poll()? { Async::Ready(None) => {
Async::Ready(None) => { self.writer.write_eof()?;
self.writer.write_eof()?; self.disconnected = true;
self.disconnected = true; break;
break }
}, Async::Ready(Some(chunk)) => {
Async::Ready(Some(chunk)) => { self.body = IoBody::Payload(body);
self.body = IoBody::Payload(body); self.writer.write(chunk.into())?
self.writer.write(chunk.into())? }
} Async::NotReady => {
Async::NotReady => { done = true;
done = true; self.body = IoBody::Payload(body);
self.body = IoBody::Payload(body); break;
break
},
} }
}, },
IoBody::Actor(mut ctx) => { IoBody::Actor(mut ctx) => {
@ -400,7 +415,7 @@ impl Pipeline {
Async::Ready(Some(vec)) => { Async::Ready(Some(vec)) => {
if vec.is_empty() { if vec.is_empty() {
self.body = IoBody::Actor(ctx); self.body = IoBody::Actor(ctx);
break break;
} }
let mut res = None; let mut res = None;
for frame in vec { for frame in vec {
@ -409,52 +424,53 @@ impl Pipeline {
// info.context = Some(ctx); // info.context = Some(ctx);
self.disconnected = true; self.disconnected = true;
self.writer.write_eof()?; self.writer.write_eof()?;
break 'outter break 'outter;
}, }
Frame::Chunk(Some(chunk)) => Frame::Chunk(Some(chunk)) => {
res = Some(self.writer.write(chunk)?), res = Some(self.writer.write(chunk)?)
}
Frame::Drain(fut) => self.drain = Some(fut), Frame::Drain(fut) => self.drain = Some(fut),
} }
} }
self.body = IoBody::Actor(ctx); self.body = IoBody::Actor(ctx);
if self.drain.is_some() { if self.drain.is_some() {
self.write_state.resume(); self.write_state.resume();
break break;
} }
res.unwrap() res.unwrap()
}, }
Async::Ready(None) => { Async::Ready(None) => {
done = true; done = true;
break break;
} }
Async::NotReady => { Async::NotReady => {
done = true; done = true;
self.body = IoBody::Actor(ctx); self.body = IoBody::Actor(ctx);
break break;
} }
} }
}, }
IoBody::Done => { IoBody::Done => {
self.disconnected = true; self.disconnected = true;
done = true; done = true;
break break;
} }
}; };
match result { match result {
WriterState::Pause => { WriterState::Pause => {
self.write_state.pause(); self.write_state.pause();
break break;
} }
WriterState::Done => { WriterState::Done => self.write_state.resume(),
self.write_state.resume()
},
} }
} }
} }
// flush io but only if we need to // flush io but only if we need to
match self.writer.poll_completed(self.conn.as_mut().unwrap(), false) { match self.writer
.poll_completed(self.conn.as_mut().unwrap(), false)
{
Ok(Async::Ready(_)) => { Ok(Async::Ready(_)) => {
if self.disconnected { if self.disconnected {
self.write_state = RunningState::Done; self.write_state = RunningState::Done;
@ -472,7 +488,7 @@ impl Pipeline {
} else { } else {
Ok(Async::NotReady) Ok(Async::NotReady)
} }
}, }
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => Err(err.into()), Err(err) => Err(err.into()),
} }

View File

@ -1,26 +1,26 @@
use std::{fmt, mem};
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::Write; use std::io::Write;
use std::time::Duration; use std::time::Duration;
use std::{fmt, mem};
use actix::{Addr, Unsync}; use actix::{Addr, Unsync};
use bytes::{BufMut, Bytes, BytesMut};
use cookie::{Cookie, CookieJar}; use cookie::{Cookie, CookieJar};
use bytes::{Bytes, BytesMut, BufMut};
use futures::Stream; use futures::Stream;
use serde_json; use percent_encoding::{percent_encode, USERINFO_ENCODE_SET};
use serde::Serialize; use serde::Serialize;
use serde_json;
use url::Url; use url::Url;
use percent_encoding::{USERINFO_ENCODE_SET, percent_encode};
use super::connector::{ClientConnector, Connection};
use super::pipeline::SendRequest;
use body::Body; use body::Body;
use error::Error; use error::Error;
use header::{ContentEncoding, Header, IntoHeaderValue}; use header::{ContentEncoding, Header, IntoHeaderValue};
use http::header::{self, HeaderName, HeaderValue};
use http::{uri, Error as HttpError, HeaderMap, HttpTryFrom, Method, Uri, Version};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use http::{uri, HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use super::pipeline::SendRequest;
use super::connector::{Connection, ClientConnector};
/// An HTTP Client Request /// An HTTP Client Request
/// ///
@ -72,7 +72,6 @@ enum ConnectionType {
} }
impl Default for ClientRequest { impl Default for ClientRequest {
fn default() -> ClientRequest { fn default() -> ClientRequest {
ClientRequest { ClientRequest {
uri: Uri::default(), uri: Uri::default(),
@ -92,7 +91,6 @@ impl Default for ClientRequest {
} }
impl ClientRequest { impl ClientRequest {
/// Create request builder for `GET` request /// Create request builder for `GET` request
pub fn get<U: AsRef<str>>(uri: U) -> ClientRequestBuilder { pub fn get<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build(); let mut builder = ClientRequest::build();
@ -130,14 +128,13 @@ impl ClientRequest {
} }
impl ClientRequest { impl ClientRequest {
/// Create client request builder /// Create client request builder
pub fn build() -> ClientRequestBuilder { pub fn build() -> ClientRequestBuilder {
ClientRequestBuilder { ClientRequestBuilder {
request: Some(ClientRequest::default()), request: Some(ClientRequest::default()),
err: None, err: None,
cookies: None, cookies: None,
default_headers: true default_headers: true,
} }
} }
@ -259,8 +256,11 @@ impl ClientRequest {
impl fmt::Debug for ClientRequest { impl fmt::Debug for ClientRequest {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(f, "\nClientRequest {:?} {}:{}", let res = writeln!(
self.version, self.method, self.uri); f,
"\nClientRequest {:?} {}:{}",
self.version, self.method, self.uri
);
let _ = writeln!(f, " headers:"); let _ = writeln!(f, " headers:");
for (key, val) in self.headers.iter() { for (key, val) in self.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val); let _ = writeln!(f, " {:?}: {:?}", key, val);
@ -277,7 +277,7 @@ pub struct ClientRequestBuilder {
request: Option<ClientRequest>, request: Option<ClientRequest>,
err: Option<HttpError>, err: Option<HttpError>,
cookies: Option<CookieJar>, cookies: Option<CookieJar>,
default_headers: bool default_headers: bool,
} }
impl ClientRequestBuilder { impl ClientRequestBuilder {
@ -300,8 +300,8 @@ impl ClientRequestBuilder {
if let Some(parts) = parts(&mut self.request, &self.err) { if let Some(parts) = parts(&mut self.request, &self.err) {
parts.uri = uri; parts.uri = uri;
} }
}, }
Err(e) => self.err = Some(e.into(),), Err(e) => self.err = Some(e.into()),
} }
self self
} }
@ -318,8 +318,8 @@ impl ClientRequestBuilder {
/// Set HTTP method of this request. /// Set HTTP method of this request.
#[inline] #[inline]
pub fn get_method(&mut self) -> &Method { pub fn get_method(&mut self) -> &Method {
let parts = parts(&mut self.request, &self.err) let parts =
.expect("cannot reuse request builder"); parts(&mut self.request, &self.err).expect("cannot reuse request builder");
&parts.method &parts.method
} }
@ -351,11 +351,12 @@ impl ClientRequestBuilder {
/// } /// }
/// ``` /// ```
#[doc(hidden)] #[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
{
if let Some(parts) = parts(&mut self.request, &self.err) { if let Some(parts) = parts(&mut self.request, &self.err) {
match hdr.try_into() { match hdr.try_into() {
Ok(value) => { parts.headers.insert(H::name(), value); } Ok(value) => {
parts.headers.insert(H::name(), value);
}
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
} }
} }
@ -382,15 +383,17 @@ impl ClientRequestBuilder {
/// } /// }
/// ``` /// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
if let Some(parts) = parts(&mut self.request, &self.err) { if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) { match HeaderName::try_from(key) {
Ok(key) => { Ok(key) => match value.try_into() {
match value.try_into() { Ok(value) => {
Ok(value) => { parts.headers.append(key, value); } parts.headers.append(key, value);
Err(e) => self.err = Some(e.into()),
} }
Err(e) => self.err = Some(e.into()),
}, },
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
}; };
@ -400,15 +403,17 @@ impl ClientRequestBuilder {
/// Set a header. /// Set a header.
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
if let Some(parts) = parts(&mut self.request, &self.err) { if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) { match HeaderName::try_from(key) {
Ok(key) => { Ok(key) => match value.try_into() {
match value.try_into() { Ok(value) => {
Ok(value) => { parts.headers.insert(key, value); } parts.headers.insert(key, value);
Err(e) => self.err = Some(e.into()),
} }
Err(e) => self.err = Some(e.into()),
}, },
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
}; };
@ -448,11 +453,14 @@ impl ClientRequestBuilder {
/// Set request's content type /// Set request's content type
#[inline] #[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self pub fn content_type<V>(&mut self, value: V) -> &mut Self
where HeaderValue: HttpTryFrom<V> where
HeaderValue: HttpTryFrom<V>,
{ {
if let Some(parts) = parts(&mut self.request, &self.err) { if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderValue::try_from(value) { match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); }, Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
}; };
} }
@ -491,7 +499,10 @@ impl ClientRequestBuilder {
jar.add(cookie.into_owned()); jar.add(cookie.into_owned());
self.cookies = Some(jar) self.cookies = Some(jar)
} else { } else {
self.cookies.as_mut().unwrap().add(cookie.into_owned()); self.cookies
.as_mut()
.unwrap()
.add(cookie.into_owned());
} }
self self
} }
@ -551,7 +562,8 @@ impl ClientRequestBuilder {
/// This method calls provided closure with builder reference if /// This method calls provided closure with builder reference if
/// value is `true`. /// value is `true`.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: FnOnce(&mut ClientRequestBuilder) where
F: FnOnce(&mut ClientRequestBuilder),
{ {
if value { if value {
f(self); f(self);
@ -562,7 +574,8 @@ impl ClientRequestBuilder {
/// This method calls provided closure with builder reference if /// This method calls provided closure with builder reference if
/// value is `Some`. /// value is `Some`.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where F: FnOnce(T, &mut ClientRequestBuilder) where
F: FnOnce(T, &mut ClientRequestBuilder),
{ {
if let Some(val) = value { if let Some(val) = value {
f(val, self); f(val, self);
@ -575,18 +588,20 @@ impl ClientRequestBuilder {
/// `ClientRequestBuilder` can not be used after this call. /// `ClientRequestBuilder` can not be used after this call.
pub fn body<B: Into<Body>>(&mut self, body: B) -> Result<ClientRequest, Error> { pub fn body<B: Into<Body>>(&mut self, body: B) -> Result<ClientRequest, Error> {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
return Err(e.into()) return Err(e.into());
} }
if self.default_headers { if self.default_headers {
// enable br only for https // enable br only for https
let https = let https = if let Some(parts) = parts(&mut self.request, &self.err) {
if let Some(parts) = parts(&mut self.request, &self.err) { parts
parts.uri.scheme_part() .uri
.map(|s| s == &uri::Scheme::HTTPS).unwrap_or(true) .scheme_part()
} else { .map(|s| s == &uri::Scheme::HTTPS)
true .unwrap_or(true)
}; } else {
true
};
if https { if https {
self.header(header::ACCEPT_ENCODING, "br, gzip, deflate"); self.header(header::ACCEPT_ENCODING, "br, gzip, deflate");
@ -595,7 +610,9 @@ impl ClientRequestBuilder {
} }
} }
let mut request = self.request.take().expect("cannot reuse request builder"); let mut request = self.request
.take()
.expect("cannot reuse request builder");
// set cookies // set cookies
if let Some(ref mut jar) = self.cookies { if let Some(ref mut jar) = self.cookies {
@ -606,7 +623,9 @@ impl ClientRequestBuilder {
let _ = write!(&mut cookie, "; {}={}", name, value); let _ = write!(&mut cookie, "; {}={}", name, value);
} }
request.headers.insert( request.headers.insert(
header::COOKIE, HeaderValue::from_str(&cookie.as_str()[2..]).unwrap()); header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
} }
request.body = body.into(); request.body = body.into();
Ok(request) Ok(request)
@ -634,10 +653,13 @@ impl ClientRequestBuilder {
/// ///
/// `ClientRequestBuilder` can not be used after this call. /// `ClientRequestBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Result<ClientRequest, Error> pub fn streaming<S, E>(&mut self, stream: S) -> Result<ClientRequest, Error>
where S: Stream<Item=Bytes, Error=E> + 'static, where
E: Into<Error>, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error>,
{ {
self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into())))) self.body(Body::Streaming(Box::new(
stream.map_err(|e| e.into()),
)))
} }
/// Set an empty body and generate `ClientRequest` /// Set an empty body and generate `ClientRequest`
@ -653,17 +675,17 @@ impl ClientRequestBuilder {
request: self.request.take(), request: self.request.take(),
err: self.err.take(), err: self.err.take(),
cookies: self.cookies.take(), cookies: self.cookies.take(),
default_headers: self.default_headers default_headers: self.default_headers,
} }
} }
} }
#[inline] #[inline]
fn parts<'a>(parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>) fn parts<'a>(
-> Option<&'a mut ClientRequest> parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>
{ ) -> Option<&'a mut ClientRequest> {
if err.is_some() { if err.is_some() {
return None return None;
} }
parts.as_mut() parts.as_mut()
} }
@ -671,8 +693,11 @@ fn parts<'a>(parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>)
impl fmt::Debug for ClientRequestBuilder { impl fmt::Debug for ClientRequestBuilder {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(ref parts) = self.request { if let Some(ref parts) = self.request {
let res = writeln!(f, "\nClientRequestBuilder {:?} {}:{}", let res = writeln!(
parts.version, parts.method, parts.uri); f,
"\nClientRequestBuilder {:?} {}:{}",
parts.version, parts.method, parts.uri
);
let _ = writeln!(f, " headers:"); let _ = writeln!(f, " headers:");
for (key, val) in parts.headers.iter() { for (key, val) in parts.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val); let _ = writeln!(f, " {:?}: {:?}", key, val);

View File

@ -1,19 +1,18 @@
use std::{fmt, str};
use std::rc::Rc;
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::rc::Rc;
use std::{fmt, str};
use bytes::Bytes; use bytes::Bytes;
use cookie::Cookie; use cookie::Cookie;
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use http::{HeaderMap, StatusCode, Version};
use http::header::{self, HeaderValue}; use http::header::{self, HeaderValue};
use http::{HeaderMap, StatusCode, Version};
use httpmessage::HttpMessage;
use error::{CookieParseError, PayloadError}; use error::{CookieParseError, PayloadError};
use httpmessage::HttpMessage;
use super::pipeline::Pipeline; use super::pipeline::Pipeline;
pub(crate) struct ClientMessage { pub(crate) struct ClientMessage {
pub status: StatusCode, pub status: StatusCode,
pub version: Version, pub version: Version,
@ -22,7 +21,6 @@ pub(crate) struct ClientMessage {
} }
impl Default for ClientMessage { impl Default for ClientMessage {
fn default() -> ClientMessage { fn default() -> ClientMessage {
ClientMessage { ClientMessage {
status: StatusCode::OK, status: StatusCode::OK,
@ -45,7 +43,6 @@ impl HttpMessage for ClientResponse {
} }
impl ClientResponse { impl ClientResponse {
pub(crate) fn new(msg: ClientMessage) -> ClientResponse { pub(crate) fn new(msg: ClientMessage) -> ClientResponse {
ClientResponse(Rc::new(UnsafeCell::new(msg)), None) ClientResponse(Rc::new(UnsafeCell::new(msg)), None)
} }
@ -56,13 +53,13 @@ impl ClientResponse {
#[inline] #[inline]
fn as_ref(&self) -> &ClientMessage { fn as_ref(&self) -> &ClientMessage {
unsafe{ &*self.0.get() } unsafe { &*self.0.get() }
} }
#[inline] #[inline]
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))]
fn as_mut(&self) -> &mut ClientMessage { fn as_mut(&self) -> &mut ClientMessage {
unsafe{ &mut *self.0.get() } unsafe { &mut *self.0.get() }
} }
/// Get the HTTP version of this response. /// Get the HTTP version of this response.
@ -96,7 +93,7 @@ impl ClientResponse {
if let Ok(cookies) = self.cookies() { if let Ok(cookies) = self.cookies() {
for cookie in cookies { for cookie in cookies {
if cookie.name() == name { if cookie.name() == name {
return Some(cookie) return Some(cookie);
} }
} }
} }
@ -107,7 +104,11 @@ impl ClientResponse {
impl fmt::Debug for ClientResponse { impl fmt::Debug for ClientResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!( let res = writeln!(
f, "\nClientResponse {:?} {}", self.version(), self.status()); f,
"\nClientResponse {:?} {}",
self.version(),
self.status()
);
let _ = writeln!(f, " headers:"); let _ = writeln!(f, " headers:");
for (key, val) in self.headers().iter() { for (key, val) in self.headers().iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val); let _ = writeln!(f, " {:?}: {:?}", key, val);
@ -138,9 +139,13 @@ mod tests {
fn test_debug() { fn test_debug() {
let resp = ClientResponse::new(ClientMessage::default()); let resp = ClientResponse::new(ClientMessage::default());
resp.as_mut().headers.insert( resp.as_mut().headers.insert(
header::COOKIE, HeaderValue::from_static("cookie1=value1")); header::COOKIE,
HeaderValue::from_static("cookie1=value1"),
);
resp.as_mut().headers.insert( resp.as_mut().headers.insert(
header::COOKIE, HeaderValue::from_static("cookie2=value2")); header::COOKIE,
HeaderValue::from_static("cookie2=value2"),
);
let dbg = format!("{:?}", resp); let dbg = format!("{:?}", resp);
assert!(dbg.contains("ClientResponse")); assert!(dbg.contains("ClientResponse"));

View File

@ -1,30 +1,29 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] #![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::io::{self, Write};
use std::cell::RefCell; use std::cell::RefCell;
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{self, Write};
use time::{self, Duration}; #[cfg(feature = "brotli")]
use bytes::{BytesMut, BufMut};
use futures::{Async, Poll};
use tokio_io::AsyncWrite;
use http::{Version, HttpTryFrom};
use http::header::{HeaderValue, DATE,
CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use flate2::Compression;
use flate2::write::{GzEncoder, DeflateEncoder};
#[cfg(feature="brotli")]
use brotli2::write::BrotliEncoder; use brotli2::write::BrotliEncoder;
use bytes::{BufMut, BytesMut};
use flate2::Compression;
use flate2::write::{DeflateEncoder, GzEncoder};
use futures::{Async, Poll};
use http::header::{HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, DATE,
TRANSFER_ENCODING};
use http::{HttpTryFrom, Version};
use time::{self, Duration};
use tokio_io::AsyncWrite;
use body::{Body, Binary}; use body::{Binary, Body};
use header::ContentEncoding; use header::ContentEncoding;
use server::WriterState; use server::WriterState;
use server::shared::SharedBytes;
use server::encoding::{ContentEncoder, TransferEncoding}; use server::encoding::{ContentEncoder, TransferEncoding};
use server::shared::SharedBytes;
use client::ClientRequest; use client::ClientRequest;
const AVERAGE_HEADER_SIZE: usize = 30; const AVERAGE_HEADER_SIZE: usize = 30;
bitflags! { bitflags! {
@ -46,7 +45,6 @@ pub(crate) struct HttpClientWriter {
} }
impl HttpClientWriter { impl HttpClientWriter {
pub fn new(buffer: SharedBytes) -> HttpClientWriter { pub fn new(buffer: SharedBytes) -> HttpClientWriter {
let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone())); let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone()));
HttpClientWriter { HttpClientWriter {
@ -64,24 +62,26 @@ impl HttpClientWriter {
} }
// pub fn keepalive(&self) -> bool { // pub fn keepalive(&self) -> bool {
// self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) // self.flags.contains(Flags::KEEPALIVE) &&
// } // !self.flags.contains(Flags::UPGRADE) }
fn write_to_stream<T: AsyncWrite>(&mut self, stream: &mut T) -> io::Result<WriterState> { fn write_to_stream<T: AsyncWrite>(
&mut self, stream: &mut T
) -> io::Result<WriterState> {
while !self.buffer.is_empty() { while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) { match stream.write(self.buffer.as_ref()) {
Ok(0) => { Ok(0) => {
self.disconnected(); self.disconnected();
return Ok(WriterState::Done); return Ok(WriterState::Done);
}, }
Ok(n) => { Ok(n) => {
let _ = self.buffer.split_to(n); let _ = self.buffer.split_to(n);
}, }
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if self.buffer.len() > self.buffer_capacity { if self.buffer.len() > self.buffer_capacity {
return Ok(WriterState::Pause) return Ok(WriterState::Pause);
} else { } else {
return Ok(WriterState::Done) return Ok(WriterState::Done);
} }
} }
Err(err) => return Err(err), Err(err) => return Err(err),
@ -92,7 +92,6 @@ impl HttpClientWriter {
} }
impl HttpClientWriter { impl HttpClientWriter {
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> { pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
// prepare task // prepare task
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
@ -105,10 +104,16 @@ impl HttpClientWriter {
// render message // render message
{ {
// status line // status line
writeln!(self.buffer, "{} {} {:?}\r", writeln!(
msg.method(), self.buffer,
msg.uri().path_and_query().map(|u| u.as_str()).unwrap_or("/"), "{} {} {:?}\r",
msg.version())?; msg.method(),
msg.uri()
.path_and_query()
.map(|u| u.as_str())
.unwrap_or("/"),
msg.version()
)?;
// write headers // write headers
let mut buffer = self.buffer.get_mut(); let mut buffer = self.buffer.get_mut();
@ -173,15 +178,17 @@ impl HttpClientWriter {
if self.encoder.is_eof() { if self.encoder.is_eof() {
Ok(()) Ok(())
} else { } else {
Err(io::Error::new(io::ErrorKind::Other, Err(io::Error::new(
"Last payload item, but eof is not reached")) io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
} }
} }
#[inline] #[inline]
pub fn poll_completed<T: AsyncWrite>(&mut self, stream: &mut T, shutdown: bool) pub fn poll_completed<T: AsyncWrite>(
-> Poll<(), io::Error> &mut self, stream: &mut T, shutdown: bool
{ ) -> Poll<(), io::Error> {
match self.write_to_stream(stream) { match self.write_to_stream(stream) {
Ok(WriterState::Done) => { Ok(WriterState::Done) => {
if shutdown { if shutdown {
@ -189,14 +196,13 @@ impl HttpClientWriter {
} else { } else {
Ok(Async::Ready(())) Ok(Async::Ready(()))
} }
}, }
Ok(WriterState::Pause) => Ok(Async::NotReady), Ok(WriterState::Pause) => Ok(Async::NotReady),
Err(err) => Err(err) Err(err) => Err(err),
} }
} }
} }
fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder { fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder {
let version = req.version(); let version = req.version();
let mut body = req.replace_body(Body::Empty); let mut body = req.replace_body(Body::Empty);
@ -206,21 +212,25 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
Body::Empty => { Body::Empty => {
req.headers_mut().remove(CONTENT_LENGTH); req.headers_mut().remove(CONTENT_LENGTH);
TransferEncoding::length(0, buf) TransferEncoding::length(0, buf)
}, }
Body::Binary(ref mut bytes) => { Body::Binary(ref mut bytes) => {
if encoding.is_compression() { if encoding.is_compression() {
let tmp = SharedBytes::default(); let tmp = SharedBytes::default();
let transfer = TransferEncoding::eof(tmp.clone()); let transfer = TransferEncoding::eof(tmp.clone());
let mut enc = match encoding { let mut enc = match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::default())), DeflateEncoder::new(transfer, Compression::default()),
ContentEncoding::Gzip => ContentEncoder::Gzip( ),
GzEncoder::new(transfer, Compression::default())), ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
#[cfg(feature="brotli")] transfer,
ContentEncoding::Br => ContentEncoder::Br( Compression::default(),
BrotliEncoder::new(transfer, 5)), )),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
ContentEncoder::Br(BrotliEncoder::new(transfer, 5))
}
ContentEncoding::Identity => ContentEncoder::Identity(transfer), ContentEncoding::Identity => ContentEncoder::Identity(transfer),
ContentEncoding::Auto => unreachable!() ContentEncoding::Auto => unreachable!(),
}; };
// TODO return error! // TODO return error!
let _ = enc.write(bytes.clone()); let _ = enc.write(bytes.clone());
@ -228,21 +238,26 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
*bytes = Binary::from(tmp.take()); *bytes = Binary::from(tmp.take());
req.headers_mut().insert( req.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
} }
let mut b = BytesMut::new(); let mut b = BytesMut::new();
let _ = write!(b, "{}", bytes.len()); let _ = write!(b, "{}", bytes.len());
req.headers_mut().insert( req.headers_mut().insert(
CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); CONTENT_LENGTH,
HeaderValue::try_from(b.freeze()).unwrap(),
);
TransferEncoding::eof(buf) TransferEncoding::eof(buf)
}, }
Body::Streaming(_) | Body::Actor(_) => { Body::Streaming(_) | Body::Actor(_) => {
if req.upgrade() { if req.upgrade() {
if version == Version::HTTP_2 { if version == Version::HTTP_2 {
error!("Connection upgrade is forbidden for HTTP/2"); error!("Connection upgrade is forbidden for HTTP/2");
} else { } else {
req.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); req.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
} }
if encoding != ContentEncoding::Identity { if encoding != ContentEncoding::Identity {
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
@ -257,24 +272,31 @@ fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder
if encoding.is_compression() { if encoding.is_compression() {
req.headers_mut().insert( req.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
} }
req.replace_body(body); req.replace_body(body);
match encoding { match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
DeflateEncoder::new(transfer, Compression::default())), transfer,
ContentEncoding::Gzip => ContentEncoder::Gzip( Compression::default(),
GzEncoder::new(transfer, Compression::default())), )),
#[cfg(feature="brotli")] ContentEncoding::Gzip => {
ContentEncoding::Br => ContentEncoder::Br( ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::default()))
BrotliEncoder::new(transfer, 5)), }
ContentEncoding::Identity | ContentEncoding::Auto => ContentEncoder::Identity(transfer), #[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
} }
} }
fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientRequest) fn streaming_encoding(
-> TransferEncoding { buf: SharedBytes, version: Version, req: &mut ClientRequest
) -> TransferEncoding {
if req.chunked() { if req.chunked() {
// Enable transfer encoding // Enable transfer encoding
req.headers_mut().remove(CONTENT_LENGTH); req.headers_mut().remove(CONTENT_LENGTH);
@ -282,29 +304,28 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
req.headers_mut().remove(TRANSFER_ENCODING); req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf) TransferEncoding::eof(buf)
} else { } else {
req.headers_mut().insert( req.headers_mut()
TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf) TransferEncoding::chunked(buf)
} }
} else { } else {
// if Content-Length is specified, then use it as length hint // if Content-Length is specified, then use it as length hint
let (len, chunked) = let (len, chunked) = if let Some(len) = req.headers().get(CONTENT_LENGTH) {
if let Some(len) = req.headers().get(CONTENT_LENGTH) { // Content-Length
// Content-Length if let Ok(s) = len.to_str() {
if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::<u64>() {
if let Ok(len) = s.parse::<u64>() { (Some(len), false)
(Some(len), false)
} else {
error!("illegal Content-Length: {:?}", len);
(None, false)
}
} else { } else {
error!("illegal Content-Length: {:?}", len); error!("illegal Content-Length: {:?}", len);
(None, false) (None, false)
} }
} else { } else {
(None, true) error!("illegal Content-Length: {:?}", len);
}; (None, false)
}
} else {
(None, true)
};
if !chunked { if !chunked {
if let Some(len) = len { if let Some(len) = len {
@ -316,10 +337,10 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
// Enable transfer encoding // Enable transfer encoding
match version { match version {
Version::HTTP_11 => { Version::HTTP_11 => {
req.headers_mut().insert( req.headers_mut()
TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf) TransferEncoding::chunked(buf)
}, }
_ => { _ => {
req.headers_mut().remove(TRANSFER_ENCODING); req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf) TransferEncoding::eof(buf)
@ -329,7 +350,6 @@ fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientReques
} }
} }
// "Sun, 06 Nov 1994 08:49:37 GMT".len() // "Sun, 06 Nov 1994 08:49:37 GMT".len()
pub const DATE_VALUE_LENGTH: usize = 29; pub const DATE_VALUE_LENGTH: usize = 29;

View File

@ -1,20 +1,19 @@
use std::mem;
use std::marker::PhantomData;
use futures::{Async, Future, Poll};
use futures::sync::oneshot::Sender; use futures::sync::oneshot::Sender;
use futures::unsync::oneshot; use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::marker::PhantomData;
use std::mem;
use actix::{Actor, ActorState, ActorContext, AsyncContext, use actix::dev::{ContextImpl, SyncEnvelope, ToEnvelope};
Addr, Handler, Message, SpawnHandle, Syn, Unsync};
use actix::fut::ActorFuture; use actix::fut::ActorFuture;
use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope}; use actix::{Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message,
SpawnHandle, Syn, Unsync};
use body::{Body, Binary}; use body::{Binary, Body};
use error::{Error, ErrorInternalServerError}; use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest; use httprequest::HttpRequest;
pub trait ActorHttpContext: 'static { pub trait ActorHttpContext: 'static {
fn disconnected(&mut self); fn disconnected(&mut self);
fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error>; fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error>;
@ -36,7 +35,9 @@ impl Frame {
} }
/// Execution context for http actors /// Execution context for http actors
pub struct HttpContext<A, S=()> where A: Actor<Context=HttpContext<A, S>>, pub struct HttpContext<A, S = ()>
where
A: Actor<Context = HttpContext<A, S>>,
{ {
inner: ContextImpl<A>, inner: ContextImpl<A>,
stream: Option<SmallVec<[Frame; 4]>>, stream: Option<SmallVec<[Frame; 4]>>,
@ -44,7 +45,9 @@ pub struct HttpContext<A, S=()> where A: Actor<Context=HttpContext<A, S>>,
disconnected: bool, disconnected: bool,
} }
impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self> impl<A, S> ActorContext for HttpContext<A, S>
where
A: Actor<Context = Self>,
{ {
fn stop(&mut self) { fn stop(&mut self) {
self.inner.stop(); self.inner.stop();
@ -57,25 +60,29 @@ impl<A, S> ActorContext for HttpContext<A, S> where A: Actor<Context=Self>
} }
} }
impl<A, S> AsyncContext<A> for HttpContext<A, S> where A: Actor<Context=Self> impl<A, S> AsyncContext<A> for HttpContext<A, S>
where
A: Actor<Context = Self>,
{ {
#[inline] #[inline]
fn spawn<F>(&mut self, fut: F) -> SpawnHandle fn spawn<F>(&mut self, fut: F) -> SpawnHandle
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{ {
self.inner.spawn(fut) self.inner.spawn(fut)
} }
#[inline] #[inline]
fn wait<F>(&mut self, fut: F) fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{ {
self.inner.wait(fut) self.inner.wait(fut)
} }
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
fn waiting(&self) -> bool { fn waiting(&self) -> bool {
self.inner.waiting() || self.inner.state() == ActorState::Stopping || self.inner.waiting() || self.inner.state() == ActorState::Stopping
self.inner.state() == ActorState::Stopped || self.inner.state() == ActorState::Stopped
} }
#[inline] #[inline]
fn cancel_future(&mut self, handle: SpawnHandle) -> bool { fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
@ -93,8 +100,10 @@ impl<A, S> AsyncContext<A> for HttpContext<A, S> where A: Actor<Context=Self>
} }
} }
impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> { impl<A, S: 'static> HttpContext<A, S>
where
A: Actor<Context = Self>,
{
#[inline] #[inline]
pub fn new(req: HttpRequest<S>, actor: A) -> HttpContext<A, S> { pub fn new(req: HttpRequest<S>, actor: A) -> HttpContext<A, S> {
HttpContext::from_request(req).actor(actor) HttpContext::from_request(req).actor(actor)
@ -114,8 +123,10 @@ impl<A, S: 'static> HttpContext<A, S> where A: Actor<Context=Self> {
} }
} }
impl<A, S> HttpContext<A, S> where A: Actor<Context=Self> { impl<A, S> HttpContext<A, S>
where
A: Actor<Context = Self>,
{
/// Shared application state /// Shared application state
#[inline] #[inline]
pub fn state(&self) -> &S { pub fn state(&self) -> &S {
@ -175,8 +186,11 @@ impl<A, S> HttpContext<A, S> where A: Actor<Context=Self> {
} }
} }
impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>, S: 'static { impl<A, S> ActorHttpContext for HttpContext<A, S>
where
A: Actor<Context = Self>,
S: 'static,
{
#[inline] #[inline]
fn disconnected(&mut self) { fn disconnected(&mut self) {
self.disconnected = true; self.disconnected = true;
@ -184,9 +198,8 @@ impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>,
} }
fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error> { fn poll(&mut self) -> Poll<Option<SmallVec<[Frame; 4]>>, Error> {
let ctx: &mut HttpContext<A, S> = unsafe { let ctx: &mut HttpContext<A, S> =
mem::transmute(self as &mut HttpContext<A, S>) unsafe { mem::transmute(self as &mut HttpContext<A, S>) };
};
if self.inner.alive() { if self.inner.alive() {
match self.inner.poll(ctx) { match self.inner.poll(ctx) {
@ -207,8 +220,10 @@ impl<A, S> ActorHttpContext for HttpContext<A, S> where A: Actor<Context=Self>,
} }
impl<A, M, S> ToEnvelope<Syn, A, M> for HttpContext<A, S> impl<A, M, S> ToEnvelope<Syn, A, M> for HttpContext<A, S>
where A: Actor<Context=HttpContext<A, S>> + Handler<M>, where
M: Message + Send + 'static, M::Result: Send, A: Actor<Context = HttpContext<A, S>> + Handler<M>,
M: Message + Send + 'static,
M::Result: Send,
{ {
fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> { fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> {
SyncEnvelope::new(msg, tx) SyncEnvelope::new(msg, tx)
@ -216,8 +231,9 @@ impl<A, M, S> ToEnvelope<Syn, A, M> for HttpContext<A, S>
} }
impl<A, S> From<HttpContext<A, S>> for Body impl<A, S> From<HttpContext<A, S>> for Body
where A: Actor<Context=HttpContext<A, S>>, where
S: 'static A: Actor<Context = HttpContext<A, S>>,
S: 'static,
{ {
fn from(ctx: HttpContext<A, S>) -> Body { fn from(ctx: HttpContext<A, S>) -> Body {
Body::Actor(Box::new(ctx)) Body::Actor(Box::new(ctx))
@ -231,7 +247,10 @@ pub struct Drain<A> {
impl<A> Drain<A> { impl<A> Drain<A> {
pub fn new(fut: oneshot::Receiver<()>) -> Self { pub fn new(fut: oneshot::Receiver<()>) -> Self {
Drain { fut, _a: PhantomData } Drain {
fut,
_a: PhantomData,
}
} }
} }
@ -241,10 +260,9 @@ impl<A: Actor> ActorFuture for Drain<A> {
type Actor = A; type Actor = A;
#[inline] #[inline]
fn poll(&mut self, fn poll(
_: &mut A, &mut self, _: &mut A, _: &mut <Self::Actor as Actor>::Context
_: &mut <Self::Actor as Actor>::Context) -> Poll<Self::Item, Self::Error> ) -> Poll<Self::Item, Self::Error> {
{
self.fut.poll().map_err(|_| ()) self.fut.poll().map_err(|_| ())
} }
} }

228
src/de.rs
View File

@ -1,11 +1,10 @@
use std::slice::Iter; use serde::de::{self, Deserializer, Error as DeError, Visitor};
use std::borrow::Cow; use std::borrow::Cow;
use std::convert::AsRef; use std::convert::AsRef;
use serde::de::{self, Deserializer, Visitor, Error as DeError}; use std::slice::Iter;
use httprequest::HttpRequest; use httprequest::HttpRequest;
macro_rules! unsupported_type { macro_rules! unsupported_type {
($trait_fn:ident, $name:expr) => { ($trait_fn:ident, $name:expr) => {
fn $trait_fn<V>(self, _: V) -> Result<V::Value, Self::Error> fn $trait_fn<V>(self, _: V) -> Result<V::Value, Self::Error>
@ -37,103 +36,136 @@ macro_rules! parse_single_value {
} }
pub struct PathDeserializer<'de, S: 'de> { pub struct PathDeserializer<'de, S: 'de> {
req: &'de HttpRequest<S> req: &'de HttpRequest<S>,
} }
impl<'de, S: 'de> PathDeserializer<'de, S> { impl<'de, S: 'de> PathDeserializer<'de, S> {
pub fn new(req: &'de HttpRequest<S>) -> Self { pub fn new(req: &'de HttpRequest<S>) -> Self {
PathDeserializer{req} PathDeserializer { req }
} }
} }
impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> {
{
type Error = de::value::Error; type Error = de::value::Error;
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_map(ParamsDeserializer{ visitor.visit_map(ParamsDeserializer {
params: self.req.match_info().iter(), params: self.req.match_info().iter(),
current: None, current: None,
}) })
} }
fn deserialize_struct<V>(self, _: &'static str, _: &'static [&'static str], visitor: V) fn deserialize_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, _: &'static [&'static str], visitor: V
where V: Visitor<'de>, ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
self.deserialize_map(visitor) self.deserialize_map(visitor)
} }
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_unit() visitor.visit_unit()
} }
fn deserialize_unit_struct<V>(self, _: &'static str, visitor: V) fn deserialize_unit_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, visitor: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
self.deserialize_unit(visitor) self.deserialize_unit(visitor)
} }
fn deserialize_newtype_struct<V>(self, _: &'static str, visitor: V) fn deserialize_newtype_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, visitor: V
where V: Visitor<'de>, ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
visitor.visit_newtype_struct(self) visitor.visit_newtype_struct(self)
} }
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_tuple<V>(
where V: Visitor<'de> self, len: usize, visitor: V
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
if self.req.match_info().len() < len { if self.req.match_info().len() < len {
Err(de::value::Error::custom( Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected {}", format!(
self.req.match_info().len(), len).as_str())) "wrong number of parameters: {} expected {}",
self.req.match_info().len(),
len
).as_str(),
))
} else { } else {
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
} }
} }
fn deserialize_tuple_struct<V>(self, _: &'static str, len: usize, visitor: V) fn deserialize_tuple_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, len: usize, visitor: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
if self.req.match_info().len() < len { if self.req.match_info().len() < len {
Err(de::value::Error::custom( Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected {}", format!(
self.req.match_info().len(), len).as_str())) "wrong number of parameters: {} expected {}",
self.req.match_info().len(),
len
).as_str(),
))
} else { } else {
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
} }
} }
fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], _: V) fn deserialize_enum<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, _: &'static [&'static str], _: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("unsupported type: enum")) Err(de::value::Error::custom("unsupported type: enum"))
} }
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
if self.req.match_info().len() != 1 { if self.req.match_info().len() != 1 {
Err(de::value::Error::custom( Err(de::value::Error::custom(
format!("wrong number of parameters: {} expected 1", format!(
self.req.match_info().len()).as_str())) "wrong number of parameters: {} expected 1",
self.req.match_info().len()
).as_str(),
))
} else { } else {
visitor.visit_str(&self.req.match_info()[0]) visitor.visit_str(&self.req.match_info()[0])
} }
} }
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de> where
V: Visitor<'de>,
{ {
visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) visitor.visit_seq(ParamsSeq {
params: self.req.match_info().iter(),
})
} }
unsupported_type!(deserialize_any, "'any'"); unsupported_type!(deserialize_any, "'any'");
@ -163,22 +195,25 @@ struct ParamsDeserializer<'de> {
current: Option<(&'de str, &'de str)>, current: Option<(&'de str, &'de str)>,
} }
impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> {
{
type Error = de::value::Error; type Error = de::value::Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where K: de::DeserializeSeed<'de>, where
K: de::DeserializeSeed<'de>,
{ {
self.current = self.params.next().map(|&(ref k, ref v)| (k.as_ref(), v.as_ref())); self.current = self.params
.next()
.map(|&(ref k, ref v)| (k.as_ref(), v.as_ref()));
match self.current { match self.current {
Some((key, _)) => Ok(Some(seed.deserialize(Key{key})?)), Some((key, _)) => Ok(Some(seed.deserialize(Key { key })?)),
None => Ok(None), None => Ok(None),
} }
} }
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error> fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where V: de::DeserializeSeed<'de>, where
V: de::DeserializeSeed<'de>,
{ {
if let Some((_, value)) = self.current.take() { if let Some((_, value)) = self.current.take() {
seed.deserialize(Value { value }) seed.deserialize(Value { value })
@ -196,13 +231,15 @@ impl<'de> Deserializer<'de> for Key<'de> {
type Error = de::value::Error; type Error = de::value::Error;
fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_str(self.key) visitor.visit_str(self.key)
} }
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("Unexpected")) Err(de::value::Error::custom("Unexpected"))
} }
@ -231,8 +268,7 @@ struct Value<'de> {
value: &'de str, value: &'de str,
} }
impl<'de> Deserializer<'de> for Value<'de> impl<'de> Deserializer<'de> for Value<'de> {
{
type Error = de::value::Error; type Error = de::value::Error;
parse_value!(deserialize_bool, visit_bool, "bool"); parse_value!(deserialize_bool, visit_bool, "bool");
@ -251,74 +287,94 @@ impl<'de> Deserializer<'de> for Value<'de>
parse_value!(deserialize_char, visit_char, "char"); parse_value!(deserialize_char, visit_char, "char");
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_unit() visitor.visit_unit()
} }
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_unit() visitor.visit_unit()
} }
fn deserialize_unit_struct<V>( fn deserialize_unit_struct<V>(
self, _: &'static str, visitor: V) -> Result<V::Value, Self::Error> self, _: &'static str, visitor: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
visitor.visit_unit() visitor.visit_unit()
} }
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_borrowed_bytes(self.value.as_bytes()) visitor.visit_borrowed_bytes(self.value.as_bytes())
} }
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_borrowed_str(self.value) visitor.visit_borrowed_str(self.value)
} }
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
visitor.visit_some(self) visitor.visit_some(self)
} }
fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], visitor: V) fn deserialize_enum<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, _: &'static [&'static str], visitor: V
where V: Visitor<'de>, ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
visitor.visit_enum(ValueEnum {value: self.value}) visitor.visit_enum(ValueEnum {
value: self.value,
})
} }
fn deserialize_newtype_struct<V>(self, _: &'static str, visitor: V) fn deserialize_newtype_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, visitor: V
where V: Visitor<'de>, ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
visitor.visit_newtype_struct(self) visitor.visit_newtype_struct(self)
} }
fn deserialize_tuple<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error> fn deserialize_tuple<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de> where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("unsupported type: tuple")) Err(de::value::Error::custom("unsupported type: tuple"))
} }
fn deserialize_struct<V>(self, _: &'static str, _: &'static [&'static str], _: V) fn deserialize_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, _: &'static [&'static str], _: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("unsupported type: struct")) Err(de::value::Error::custom("unsupported type: struct"))
} }
fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) fn deserialize_tuple_struct<V>(
-> Result<V::Value, Self::Error> self, _: &'static str, _: usize, _: V
where V: Visitor<'de> ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("unsupported type: tuple struct")) Err(de::value::Error::custom(
"unsupported type: tuple struct",
))
} }
unsupported_type!(deserialize_any, "any"); unsupported_type!(deserialize_any, "any");
@ -331,15 +387,17 @@ struct ParamsSeq<'de> {
params: Iter<'de, (Cow<'de, str>, Cow<'de, str>)>, params: Iter<'de, (Cow<'de, str>, Cow<'de, str>)>,
} }
impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> {
{
type Error = de::value::Error; type Error = de::value::Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error> fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where T: de::DeserializeSeed<'de>, where
T: de::DeserializeSeed<'de>,
{ {
match self.params.next() { match self.params.next() {
Some(item) => Ok(Some(seed.deserialize(Value { value: item.1.as_ref() })?)), Some(item) => Ok(Some(seed.deserialize(Value {
value: item.1.as_ref(),
})?)),
None => Ok(None), None => Ok(None),
} }
} }
@ -354,9 +412,13 @@ impl<'de> de::EnumAccess<'de> for ValueEnum<'de> {
type Variant = UnitVariant; type Variant = UnitVariant;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where V: de::DeserializeSeed<'de>, where
V: de::DeserializeSeed<'de>,
{ {
Ok((seed.deserialize(Key { key: self.value })?, UnitVariant)) Ok((
seed.deserialize(Key { key: self.value })?,
UnitVariant,
))
} }
} }
@ -370,20 +432,24 @@ impl<'de> de::VariantAccess<'de> for UnitVariant {
} }
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error> fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error>
where T: de::DeserializeSeed<'de>, where
T: de::DeserializeSeed<'de>,
{ {
Err(de::value::Error::custom("not supported")) Err(de::value::Error::custom("not supported"))
} }
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error> fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
where V: Visitor<'de>, where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("not supported")) Err(de::value::Error::custom("not supported"))
} }
fn struct_variant<V>(self, _: &'static [&'static str], _: V) fn struct_variant<V>(
-> Result<V::Value, Self::Error> self, _: &'static [&'static str], _: V
where V: Visitor<'de>, ) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{ {
Err(de::value::Error::custom("not supported")) Err(de::value::Error::custom("not supported"))
} }

View File

@ -1,24 +1,24 @@
//! Error and Result module //! Error and Result module
use std::{io, fmt, result}; use std::io::Error as IoError;
use std::str::Utf8Error; use std::str::Utf8Error;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use std::io::Error as IoError; use std::{fmt, io, result};
use cookie;
use httparse;
use actix::MailboxError; use actix::MailboxError;
use cookie;
use failure::{self, Backtrace, Fail};
use futures::Canceled; use futures::Canceled;
use failure::{self, Fail, Backtrace};
use http2::Error as Http2Error;
use http::{header, StatusCode, Error as HttpError};
use http::uri::InvalidUri; use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode};
use http2::Error as Http2Error;
use http_range::HttpRangeParseError; use http_range::HttpRangeParseError;
use httparse;
use serde::de::value::Error as DeError; use serde::de::value::Error as DeError;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
pub use url::ParseError as UrlParseError; pub use url::ParseError as UrlParseError;
// re-exports // re-exports
pub use cookie::{ParseError as CookieParseError}; pub use cookie::ParseError as CookieParseError;
use handler::Responder; use handler::Responder;
use httprequest::HttpRequest; use httprequest::HttpRequest;
@ -27,9 +27,10 @@ use httpresponse::HttpResponse;
/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html)
/// for actix web operations /// for actix web operations
/// ///
/// This typedef is generally used to avoid writing out `actix_web::error::Error` directly and /// This typedef is generally used to avoid writing out
/// is otherwise a direct mapping to `Result`. /// `actix_web::error::Error` directly and is otherwise a direct mapping to
pub type Result<T, E=Error> = result::Result<T, E>; /// `Result`.
pub type Result<T, E = Error> = result::Result<T, E>;
/// General purpose actix web error /// General purpose actix web error
pub struct Error { pub struct Error {
@ -38,7 +39,6 @@ pub struct Error {
} }
impl Error { impl Error {
/// Returns a reference to the underlying cause of this Error. /// Returns a reference to the underlying cause of this Error.
// this should return &Fail but needs this https://github.com/rust-lang/rust/issues/5665 // this should return &Fail but needs this https://github.com/rust-lang/rust/issues/5665
pub fn cause(&self) -> &ResponseError { pub fn cause(&self) -> &ResponseError {
@ -48,7 +48,6 @@ impl Error {
/// Error that can be converted to `HttpResponse` /// Error that can be converted to `HttpResponse`
pub trait ResponseError: Fail { pub trait ResponseError: Fail {
/// Create response for error /// Create response for error
/// ///
/// Internal server error is generated by default. /// Internal server error is generated by default.
@ -68,7 +67,12 @@ impl fmt::Debug for Error {
if let Some(bt) = self.cause.backtrace() { if let Some(bt) = self.cause.backtrace() {
write!(f, "{:?}\n\n{:?}", &self.cause, bt) write!(f, "{:?}\n\n{:?}", &self.cause, bt)
} else { } else {
write!(f, "{:?}\n\n{:?}", &self.cause, self.backtrace.as_ref().unwrap()) write!(
f,
"{:?}\n\n{:?}",
&self.cause,
self.backtrace.as_ref().unwrap()
)
} }
} }
} }
@ -88,13 +92,19 @@ impl<T: ResponseError> From<T> for Error {
} else { } else {
None None
}; };
Error { cause: Box::new(err), backtrace } Error {
cause: Box::new(err),
backtrace,
}
} }
} }
/// Compatibility for `failure::Error` /// Compatibility for `failure::Error`
impl<T> ResponseError for failure::Compat<T> impl<T> ResponseError for failure::Compat<T>
where T: fmt::Display + fmt::Debug + Sync + Send + 'static { } where
T: fmt::Display + fmt::Debug + Sync + Send + 'static,
{
}
impl From<failure::Error> for Error { impl From<failure::Error> for Error {
fn from(err: failure::Error) -> Error { fn from(err: failure::Error) -> Error {
@ -128,15 +138,11 @@ impl ResponseError for HttpError {}
/// Return `InternalServerError` for `io::Error` /// Return `InternalServerError` for `io::Error`
impl ResponseError for io::Error { impl ResponseError for io::Error {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match self.kind() { match self.kind() {
io::ErrorKind::NotFound => io::ErrorKind::NotFound => HttpResponse::new(StatusCode::NOT_FOUND),
HttpResponse::new(StatusCode::NOT_FOUND), io::ErrorKind::PermissionDenied => HttpResponse::new(StatusCode::FORBIDDEN),
io::ErrorKind::PermissionDenied => _ => HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR),
HttpResponse::new(StatusCode::FORBIDDEN),
_ =>
HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
} }
} }
} }
@ -145,7 +151,7 @@ impl ResponseError for io::Error {
impl ResponseError for header::InvalidHeaderValue { impl ResponseError for header::InvalidHeaderValue {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST) HttpResponse::new(StatusCode::BAD_REQUEST)
} }
} }
/// `BadRequest` for `InvalidHeaderValue` /// `BadRequest` for `InvalidHeaderValue`
@ -165,35 +171,36 @@ impl ResponseError for MailboxError {}
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum ParseError { pub enum ParseError {
/// An invalid `Method`, such as `GE.T`. /// An invalid `Method`, such as `GE.T`.
#[fail(display="Invalid Method specified")] #[fail(display = "Invalid Method specified")]
Method, Method,
/// An invalid `Uri`, such as `exam ple.domain`. /// An invalid `Uri`, such as `exam ple.domain`.
#[fail(display="Uri error: {}", _0)] #[fail(display = "Uri error: {}", _0)]
Uri(InvalidUri), Uri(InvalidUri),
/// An invalid `HttpVersion`, such as `HTP/1.1` /// An invalid `HttpVersion`, such as `HTP/1.1`
#[fail(display="Invalid HTTP version specified")] #[fail(display = "Invalid HTTP version specified")]
Version, Version,
/// An invalid `Header`. /// An invalid `Header`.
#[fail(display="Invalid Header provided")] #[fail(display = "Invalid Header provided")]
Header, Header,
/// A message head is too large to be reasonable. /// A message head is too large to be reasonable.
#[fail(display="Message head is too large")] #[fail(display = "Message head is too large")]
TooLarge, TooLarge,
/// A message reached EOF, but is not complete. /// A message reached EOF, but is not complete.
#[fail(display="Message is incomplete")] #[fail(display = "Message is incomplete")]
Incomplete, Incomplete,
/// An invalid `Status`, such as `1337 ELITE`. /// An invalid `Status`, such as `1337 ELITE`.
#[fail(display="Invalid Status provided")] #[fail(display = "Invalid Status provided")]
Status, Status,
/// A timeout occurred waiting for an IO event. /// A timeout occurred waiting for an IO event.
#[allow(dead_code)] #[allow(dead_code)]
#[fail(display="Timeout")] #[fail(display = "Timeout")]
Timeout, Timeout,
/// An `io::Error` that occurred while trying to read or write to a network stream. /// An `io::Error` that occurred while trying to read or write to a network
#[fail(display="IO error: {}", _0)] /// stream.
#[fail(display = "IO error: {}", _0)]
Io(#[cause] IoError), Io(#[cause] IoError),
/// Parsing a field as string failed /// Parsing a field as string failed
#[fail(display="UTF8 error: {}", _0)] #[fail(display = "UTF8 error: {}", _0)]
Utf8(#[cause] Utf8Error), Utf8(#[cause] Utf8Error),
} }
@ -231,8 +238,10 @@ impl From<FromUtf8Error> for ParseError {
impl From<httparse::Error> for ParseError { impl From<httparse::Error> for ParseError {
fn from(err: httparse::Error) -> ParseError { fn from(err: httparse::Error) -> ParseError {
match err { match err {
httparse::Error::HeaderName | httparse::Error::HeaderValue | httparse::Error::HeaderName
httparse::Error::NewLine | httparse::Error::Token => ParseError::Header, | httparse::Error::HeaderValue
| httparse::Error::NewLine
| httparse::Error::Token => ParseError::Header,
httparse::Error::Status => ParseError::Status, httparse::Error::Status => ParseError::Status,
httparse::Error::TooManyHeaders => ParseError::TooLarge, httparse::Error::TooManyHeaders => ParseError::TooLarge,
httparse::Error::Version => ParseError::Version, httparse::Error::Version => ParseError::Version,
@ -244,22 +253,22 @@ impl From<httparse::Error> for ParseError {
/// A set of errors that can occur during payload parsing /// A set of errors that can occur during payload parsing
pub enum PayloadError { pub enum PayloadError {
/// A payload reached EOF, but is not complete. /// A payload reached EOF, but is not complete.
#[fail(display="A payload reached EOF, but is not complete.")] #[fail(display = "A payload reached EOF, but is not complete.")]
Incomplete, Incomplete,
/// Content encoding stream corruption /// Content encoding stream corruption
#[fail(display="Can not decode content-encoding.")] #[fail(display = "Can not decode content-encoding.")]
EncodingCorrupted, EncodingCorrupted,
/// A payload reached size limit. /// A payload reached size limit.
#[fail(display="A payload reached size limit.")] #[fail(display = "A payload reached size limit.")]
Overflow, Overflow,
/// A payload length is unknown. /// A payload length is unknown.
#[fail(display="A payload length is unknown.")] #[fail(display = "A payload length is unknown.")]
UnknownLength, UnknownLength,
/// Io error /// Io error
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Io(#[cause] IoError), Io(#[cause] IoError),
/// Http2 error /// Http2 error
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Http2(#[cause] Http2Error), Http2(#[cause] Http2Error),
} }
@ -283,12 +292,12 @@ impl ResponseError for cookie::ParseError {
#[derive(Fail, PartialEq, Debug)] #[derive(Fail, PartialEq, Debug)]
pub enum HttpRangeError { pub enum HttpRangeError {
/// Returned if range is invalid. /// Returned if range is invalid.
#[fail(display="Range header is invalid")] #[fail(display = "Range header is invalid")]
InvalidRange, InvalidRange,
/// Returned if first-byte-pos of all of the byte-range-spec /// Returned if first-byte-pos of all of the byte-range-spec
/// values is greater than the content size. /// values is greater than the content size.
/// See `https://github.com/golang/go/commit/aa9b3d7` /// See `https://github.com/golang/go/commit/aa9b3d7`
#[fail(display="First-byte-pos of all of the byte-range-spec values is greater than the content size")] #[fail(display = "First-byte-pos of all of the byte-range-spec values is greater than the content size")]
NoOverlap, NoOverlap,
} }
@ -296,7 +305,9 @@ pub enum HttpRangeError {
impl ResponseError for HttpRangeError { impl ResponseError for HttpRangeError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::with_body( HttpResponse::with_body(
StatusCode::BAD_REQUEST, "Invalid Range header provided") StatusCode::BAD_REQUEST,
"Invalid Range header provided",
)
} }
} }
@ -313,22 +324,22 @@ impl From<HttpRangeParseError> for HttpRangeError {
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum MultipartError { pub enum MultipartError {
/// Content-Type header is not found /// Content-Type header is not found
#[fail(display="No Content-type header found")] #[fail(display = "No Content-type header found")]
NoContentType, NoContentType,
/// Can not parse Content-Type header /// Can not parse Content-Type header
#[fail(display="Can not parse Content-Type header")] #[fail(display = "Can not parse Content-Type header")]
ParseContentType, ParseContentType,
/// Multipart boundary is not found /// Multipart boundary is not found
#[fail(display="Multipart boundary is not found")] #[fail(display = "Multipart boundary is not found")]
Boundary, Boundary,
/// Multipart stream is incomplete /// Multipart stream is incomplete
#[fail(display="Multipart stream is incomplete")] #[fail(display = "Multipart stream is incomplete")]
Incomplete, Incomplete,
/// Error during field parsing /// Error during field parsing
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Parse(#[cause] ParseError), Parse(#[cause] ParseError),
/// Payload error /// Payload error
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Payload(#[cause] PayloadError), Payload(#[cause] PayloadError),
} }
@ -346,7 +357,6 @@ impl From<PayloadError> for MultipartError {
/// Return `BadRequest` for `MultipartError` /// Return `BadRequest` for `MultipartError`
impl ResponseError for MultipartError { impl ResponseError for MultipartError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST) HttpResponse::new(StatusCode::BAD_REQUEST)
} }
@ -356,10 +366,10 @@ impl ResponseError for MultipartError {
#[derive(Fail, PartialEq, Debug)] #[derive(Fail, PartialEq, Debug)]
pub enum ExpectError { pub enum ExpectError {
/// Expect header value can not be converted to utf8 /// Expect header value can not be converted to utf8
#[fail(display="Expect header value can not be converted to utf8")] #[fail(display = "Expect header value can not be converted to utf8")]
Encoding, Encoding,
/// Unknown expect value /// Unknown expect value
#[fail(display="Unknown expect value")] #[fail(display = "Unknown expect value")]
UnknownExpect, UnknownExpect,
} }
@ -373,10 +383,10 @@ impl ResponseError for ExpectError {
#[derive(Fail, PartialEq, Debug)] #[derive(Fail, PartialEq, Debug)]
pub enum ContentTypeError { pub enum ContentTypeError {
/// Can not parse content type /// Can not parse content type
#[fail(display="Can not parse content type")] #[fail(display = "Can not parse content type")]
ParseError, ParseError,
/// Unknown content encoding /// Unknown content encoding
#[fail(display="Unknown content encoding")] #[fail(display = "Unknown content encoding")]
UnknownEncoding, UnknownEncoding,
} }
@ -391,36 +401,36 @@ impl ResponseError for ContentTypeError {
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum UrlencodedError { pub enum UrlencodedError {
/// Can not decode chunked transfer encoding /// Can not decode chunked transfer encoding
#[fail(display="Can not decode chunked transfer encoding")] #[fail(display = "Can not decode chunked transfer encoding")]
Chunked, Chunked,
/// Payload size is bigger than 256k /// Payload size is bigger than 256k
#[fail(display="Payload size is bigger than 256k")] #[fail(display = "Payload size is bigger than 256k")]
Overflow, Overflow,
/// Payload size is now known /// Payload size is now known
#[fail(display="Payload size is now known")] #[fail(display = "Payload size is now known")]
UnknownLength, UnknownLength,
/// Content type error /// Content type error
#[fail(display="Content type error")] #[fail(display = "Content type error")]
ContentType, ContentType,
/// Parse error /// Parse error
#[fail(display="Parse error")] #[fail(display = "Parse error")]
Parse, Parse,
/// Payload error /// Payload error
#[fail(display="Error that occur during reading payload: {}", _0)] #[fail(display = "Error that occur during reading payload: {}", _0)]
Payload(#[cause] PayloadError), Payload(#[cause] PayloadError),
} }
/// Return `BadRequest` for `UrlencodedError` /// Return `BadRequest` for `UrlencodedError`
impl ResponseError for UrlencodedError { impl ResponseError for UrlencodedError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { match *self {
UrlencodedError::Overflow => UrlencodedError::Overflow => {
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE)
UrlencodedError::UnknownLength => }
HttpResponse::new(StatusCode::LENGTH_REQUIRED), UrlencodedError::UnknownLength => {
_ => HttpResponse::new(StatusCode::LENGTH_REQUIRED)
HttpResponse::new(StatusCode::BAD_REQUEST), }
_ => HttpResponse::new(StatusCode::BAD_REQUEST),
} }
} }
} }
@ -435,28 +445,27 @@ impl From<PayloadError> for UrlencodedError {
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum JsonPayloadError { pub enum JsonPayloadError {
/// Payload size is bigger than 256k /// Payload size is bigger than 256k
#[fail(display="Payload size is bigger than 256k")] #[fail(display = "Payload size is bigger than 256k")]
Overflow, Overflow,
/// Content type error /// Content type error
#[fail(display="Content type error")] #[fail(display = "Content type error")]
ContentType, ContentType,
/// Deserialize error /// Deserialize error
#[fail(display="Json deserialize error: {}", _0)] #[fail(display = "Json deserialize error: {}", _0)]
Deserialize(#[cause] JsonError), Deserialize(#[cause] JsonError),
/// Payload error /// Payload error
#[fail(display="Error that occur during reading payload: {}", _0)] #[fail(display = "Error that occur during reading payload: {}", _0)]
Payload(#[cause] PayloadError), Payload(#[cause] PayloadError),
} }
/// Return `BadRequest` for `UrlencodedError` /// Return `BadRequest` for `UrlencodedError`
impl ResponseError for JsonPayloadError { impl ResponseError for JsonPayloadError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { match *self {
JsonPayloadError::Overflow => JsonPayloadError::Overflow => {
HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE)
_ => }
HttpResponse::new(StatusCode::BAD_REQUEST), _ => HttpResponse::new(StatusCode::BAD_REQUEST),
} }
} }
} }
@ -478,19 +487,18 @@ impl From<JsonError> for JsonPayloadError {
#[derive(Fail, Debug, PartialEq)] #[derive(Fail, Debug, PartialEq)]
pub enum UriSegmentError { pub enum UriSegmentError {
/// The segment started with the wrapped invalid character. /// The segment started with the wrapped invalid character.
#[fail(display="The segment started with the wrapped invalid character")] #[fail(display = "The segment started with the wrapped invalid character")]
BadStart(char), BadStart(char),
/// The segment contained the wrapped invalid character. /// The segment contained the wrapped invalid character.
#[fail(display="The segment contained the wrapped invalid character")] #[fail(display = "The segment contained the wrapped invalid character")]
BadChar(char), BadChar(char),
/// The segment ended with the wrapped invalid character. /// The segment ended with the wrapped invalid character.
#[fail(display="The segment ended with the wrapped invalid character")] #[fail(display = "The segment ended with the wrapped invalid character")]
BadEnd(char), BadEnd(char),
} }
/// Return `BadRequest` for `UriSegmentError` /// Return `BadRequest` for `UriSegmentError`
impl ResponseError for UriSegmentError { impl ResponseError for UriSegmentError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::new(StatusCode::BAD_REQUEST) HttpResponse::new(StatusCode::BAD_REQUEST)
} }
@ -499,13 +507,13 @@ impl ResponseError for UriSegmentError {
/// Errors which can occur when attempting to generate resource uri. /// Errors which can occur when attempting to generate resource uri.
#[derive(Fail, Debug, PartialEq)] #[derive(Fail, Debug, PartialEq)]
pub enum UrlGenerationError { pub enum UrlGenerationError {
#[fail(display="Resource not found")] #[fail(display = "Resource not found")]
ResourceNotFound, ResourceNotFound,
#[fail(display="Not all path pattern covered")] #[fail(display = "Not all path pattern covered")]
NotEnoughElements, NotEnoughElements,
#[fail(display="Router is not available")] #[fail(display = "Router is not available")]
RouterNotAvailable, RouterNotAvailable,
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
ParseError(#[cause] UrlParseError), ParseError(#[cause] UrlParseError),
} }
@ -520,8 +528,9 @@ impl From<UrlParseError> for UrlGenerationError {
/// Helper type that can wrap any error and generate custom response. /// Helper type that can wrap any error and generate custom response.
/// ///
/// In following example any `io::Error` will be converted into "BAD REQUEST" response /// In following example any `io::Error` will be converted into "BAD REQUEST"
/// as opposite to *INNTERNAL SERVER ERROR* which is defined by default. /// response as opposite to *INNTERNAL SERVER ERROR* which is defined by
/// default.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -554,7 +563,8 @@ impl<T> InternalError<T> {
} }
impl<T> Fail for InternalError<T> impl<T> Fail for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
fn backtrace(&self) -> Option<&Backtrace> { fn backtrace(&self) -> Option<&Backtrace> {
Some(&self.backtrace) Some(&self.backtrace)
@ -562,7 +572,8 @@ impl<T> Fail for InternalError<T>
} }
impl<T> fmt::Debug for InternalError<T> impl<T> fmt::Debug for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.cause, f) fmt::Debug::fmt(&self.cause, f)
@ -570,7 +581,8 @@ impl<T> fmt::Debug for InternalError<T>
} }
impl<T> fmt::Display for InternalError<T> impl<T> fmt::Display for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.cause, f) fmt::Debug::fmt(&self.cause, f)
@ -578,7 +590,8 @@ impl<T> fmt::Display for InternalError<T>
} }
impl<T> ResponseError for InternalError<T> impl<T> ResponseError for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::new(self.status) HttpResponse::new(self.status)
@ -586,7 +599,8 @@ impl<T> ResponseError for InternalError<T>
} }
impl<T> Responder for InternalError<T> impl<T> Responder for InternalError<T>
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
type Item = HttpResponse; type Item = HttpResponse;
type Error = Error; type Error = Error;
@ -596,82 +610,102 @@ impl<T> Responder for InternalError<T>
} }
} }
/// Helper function that creates wrapper of any error and generate *BAD REQUEST* response. /// Helper function that creates wrapper of any error and generate *BAD
/// REQUEST* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorBadRequest<T>(err: T) -> Error pub fn ErrorBadRequest<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::BAD_REQUEST).into() InternalError::new(err, StatusCode::BAD_REQUEST).into()
} }
/// Helper function that creates wrapper of any error and generate *UNAUTHORIZED* response. /// Helper function that creates wrapper of any error and generate
/// *UNAUTHORIZED* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorUnauthorized<T>(err: T) -> Error pub fn ErrorUnauthorized<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::UNAUTHORIZED).into() InternalError::new(err, StatusCode::UNAUTHORIZED).into()
} }
/// Helper function that creates wrapper of any error and generate *FORBIDDEN* response. /// Helper function that creates wrapper of any error and generate *FORBIDDEN*
/// response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorForbidden<T>(err: T) -> Error pub fn ErrorForbidden<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::FORBIDDEN).into() InternalError::new(err, StatusCode::FORBIDDEN).into()
} }
/// Helper function that creates wrapper of any error and generate *NOT FOUND* response. /// Helper function that creates wrapper of any error and generate *NOT FOUND*
/// response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorNotFound<T>(err: T) -> Error pub fn ErrorNotFound<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::NOT_FOUND).into() InternalError::new(err, StatusCode::NOT_FOUND).into()
} }
/// Helper function that creates wrapper of any error and generate *METHOD NOT ALLOWED* response. /// Helper function that creates wrapper of any error and generate *METHOD NOT
/// ALLOWED* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorMethodNotAllowed<T>(err: T) -> Error pub fn ErrorMethodNotAllowed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into()
} }
/// Helper function that creates wrapper of any error and generate *REQUEST TIMEOUT* response. /// Helper function that creates wrapper of any error and generate *REQUEST
/// TIMEOUT* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorRequestTimeout<T>(err: T) -> Error pub fn ErrorRequestTimeout<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into()
} }
/// Helper function that creates wrapper of any error and generate *CONFLICT* response. /// Helper function that creates wrapper of any error and generate *CONFLICT*
/// response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorConflict<T>(err: T) -> Error pub fn ErrorConflict<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::CONFLICT).into() InternalError::new(err, StatusCode::CONFLICT).into()
} }
/// Helper function that creates wrapper of any error and generate *GONE* response. /// Helper function that creates wrapper of any error and generate *GONE*
/// response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorGone<T>(err: T) -> Error pub fn ErrorGone<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::GONE).into() InternalError::new(err, StatusCode::GONE).into()
} }
/// Helper function that creates wrapper of any error and generate *PRECONDITION FAILED* response. /// Helper function that creates wrapper of any error and generate
/// *PRECONDITION FAILED* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorPreconditionFailed<T>(err: T) -> Error pub fn ErrorPreconditionFailed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() InternalError::new(err, StatusCode::PRECONDITION_FAILED).into()
} }
/// Helper function that creates wrapper of any error and generate *EXPECTATION FAILED* response. /// Helper function that creates wrapper of any error and generate
/// *EXPECTATION FAILED* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorExpectationFailed<T>(err: T) -> Error pub fn ErrorExpectationFailed<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() InternalError::new(err, StatusCode::EXPECTATION_FAILED).into()
} }
@ -680,26 +714,28 @@ pub fn ErrorExpectationFailed<T>(err: T) -> Error
/// generate *INTERNAL SERVER ERROR* response. /// generate *INTERNAL SERVER ERROR* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn ErrorInternalServerError<T>(err: T) -> Error pub fn ErrorInternalServerError<T>(err: T) -> Error
where T: Send + Sync + fmt::Debug + 'static where
T: Send + Sync + fmt::Debug + 'static,
{ {
InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into()
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use cookie::ParseError as CookieParseError;
use failure;
use http::{Error as HttpError, StatusCode};
use httparse;
use std::env; use std::env;
use std::error::Error as StdError; use std::error::Error as StdError;
use std::io; use std::io;
use httparse;
use http::{StatusCode, Error as HttpError};
use cookie::ParseError as CookieParseError;
use failure;
use super::*;
#[test] #[test]
#[cfg(actix_nightly)] #[cfg(actix_nightly)]
fn test_nightly() { fn test_nightly() {
let resp: HttpResponse = IoError::new(io::ErrorKind::Other, "test").error_response(); let resp: HttpResponse =
IoError::new(io::ErrorKind::Other, "test").error_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
} }
@ -775,10 +811,10 @@ mod tests {
match ParseError::from($from) { match ParseError::from($from) {
e @ $error => { e @ $error => {
assert!(format!("{}", e).len() >= 5); assert!(format!("{}", e).len() >= 5);
} , }
e => unreachable!("{:?}", e) e => unreachable!("{:?}", e),
} }
} };
} }
macro_rules! from_and_cause { macro_rules! from_and_cause {
@ -787,10 +823,10 @@ mod tests {
e @ $error => { e @ $error => {
let desc = format!("{}", e.cause().unwrap()); let desc = format!("{}", e.cause().unwrap());
assert_eq!(desc, $from.description().to_owned()); assert_eq!(desc, $from.description().to_owned());
}, }
_ => unreachable!("{:?}", $from) _ => unreachable!("{:?}", $from),
} }
} };
} }
#[test] #[test]
@ -814,7 +850,10 @@ mod tests {
env::set_var(NAME, "0"); env::set_var(NAME, "0");
let error = failure::err_msg("Hello!"); let error = failure::err_msg("Hello!");
let resp: Error = error.into(); let resp: Error = error.into();
assert_eq!(format!("{:?}", resp), "Compat { error: ErrorMessage { msg: \"Hello!\" } }\n\n"); assert_eq!(
format!("{:?}", resp),
"Compat { error: ErrorMessage { msg: \"Hello!\" } }\n\n"
);
match old_tb { match old_tb {
Ok(x) => env::set_var(NAME, x), Ok(x) => env::set_var(NAME, x),
_ => env::remove_var(NAME), _ => env::remove_var(NAME),

View File

@ -1,19 +1,19 @@
use std::str;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::str;
use mime::Mime;
use bytes::Bytes; use bytes::Bytes;
use serde_urlencoded;
use serde::de::{self, DeserializeOwned};
use futures::future::{Future, FutureResult, result};
use encoding::all::UTF_8; use encoding::all::UTF_8;
use encoding::types::{Encoding, DecoderTrap}; use encoding::types::{DecoderTrap, Encoding};
use futures::future::{result, Future, FutureResult};
use mime::Mime;
use serde::de::{self, DeserializeOwned};
use serde_urlencoded;
use de::PathDeserializer;
use error::{Error, ErrorBadRequest}; use error::{Error, ErrorBadRequest};
use handler::{Either, FromRequest}; use handler::{Either, FromRequest};
use httprequest::HttpRequest;
use httpmessage::{HttpMessage, MessageBody, UrlEncoded}; use httpmessage::{HttpMessage, MessageBody, UrlEncoded};
use de::PathDeserializer; use httprequest::HttpRequest;
/// Extract typed information from the request's path. /// Extract typed information from the request's path.
/// ///
@ -39,8 +39,8 @@ use de::PathDeserializer;
/// } /// }
/// ``` /// ```
/// ///
/// It is possible to extract path information to a specific type that implements /// It is possible to extract path information to a specific type that
/// `Deserialize` trait from *serde*. /// implements `Deserialize` trait from *serde*.
/// ///
/// ```rust /// ```rust
/// # extern crate bytes; /// # extern crate bytes;
@ -65,12 +65,11 @@ use de::PathDeserializer;
/// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor /// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor
/// } /// }
/// ``` /// ```
pub struct Path<T>{ pub struct Path<T> {
inner: T inner: T,
} }
impl<T> AsRef<T> for Path<T> { impl<T> AsRef<T> for Path<T> {
fn as_ref(&self) -> &T { fn as_ref(&self) -> &T {
&self.inner &self.inner
} }
@ -98,7 +97,9 @@ impl<T> Path<T> {
} }
impl<T, S> FromRequest<S> for Path<T> impl<T, S> FromRequest<S> for Path<T>
where T: DeserializeOwned, S: 'static where
T: DeserializeOwned,
S: 'static,
{ {
type Config = (); type Config = ();
type Result = FutureResult<Self, Error>; type Result = FutureResult<Self, Error>;
@ -106,9 +107,11 @@ impl<T, S> FromRequest<S> for Path<T>
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result {
let req = req.clone(); let req = req.clone();
result(de::Deserialize::deserialize(PathDeserializer::new(&req)) result(
.map_err(|e| e.into()) de::Deserialize::deserialize(PathDeserializer::new(&req))
.map(|inner| Path{inner})) .map_err(|e| e.into())
.map(|inner| Path { inner }),
)
} }
} }
@ -164,7 +167,9 @@ impl<T> Query<T> {
} }
impl<T, S> FromRequest<S> for Query<T> impl<T, S> FromRequest<S> for Query<T>
where T: de::DeserializeOwned, S: 'static where
T: de::DeserializeOwned,
S: 'static,
{ {
type Config = (); type Config = ();
type Result = FutureResult<Self, Error>; type Result = FutureResult<Self, Error>;
@ -172,24 +177,26 @@ impl<T, S> FromRequest<S> for Query<T>
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result {
let req = req.clone(); let req = req.clone();
result(serde_urlencoded::from_str::<T>(req.query_string()) result(
.map_err(|e| e.into()) serde_urlencoded::from_str::<T>(req.query_string())
.map(Query)) .map_err(|e| e.into())
.map(Query),
)
} }
} }
/// Extract typed information from the request's body. /// Extract typed information from the request's body.
/// ///
/// To extract typed information from request's body, the type `T` must implement the /// To extract typed information from request's body, the type `T` must
/// `Deserialize` trait from *serde*. /// implement the `Deserialize` trait from *serde*.
/// ///
/// [**FormConfig**](dev/struct.FormConfig.html) allows to configure extraction /// [**FormConfig**](dev/struct.FormConfig.html) allows to configure extraction
/// process. /// process.
/// ///
/// ## Example /// ## Example
/// ///
/// It is possible to extract path information to a specific type that implements /// It is possible to extract path information to a specific type that
/// `Deserialize` trait from *serde*. /// implements `Deserialize` trait from *serde*.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -233,17 +240,21 @@ impl<T> DerefMut for Form<T> {
} }
impl<T, S> FromRequest<S> for Form<T> impl<T, S> FromRequest<S> for Form<T>
where T: DeserializeOwned + 'static, S: 'static where
T: DeserializeOwned + 'static,
S: 'static,
{ {
type Config = FormConfig; type Config = FormConfig;
type Result = Box<Future<Item=Self, Error=Error>>; type Result = Box<Future<Item = Self, Error = Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
Box::new(UrlEncoded::new(req.clone()) Box::new(
.limit(cfg.limit) UrlEncoded::new(req.clone())
.from_err() .limit(cfg.limit)
.map(Form)) .from_err()
.map(Form),
)
} }
} }
@ -279,7 +290,6 @@ pub struct FormConfig {
} }
impl FormConfig { impl FormConfig {
/// Change max size of payload. By default max size is 256Kb /// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self { pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit; self.limit = limit;
@ -289,7 +299,7 @@ impl FormConfig {
impl Default for FormConfig { impl Default for FormConfig {
fn default() -> Self { fn default() -> Self {
FormConfig{limit: 262_144} FormConfig { limit: 262_144 }
} }
} }
@ -297,8 +307,8 @@ impl Default for FormConfig {
/// ///
/// Loads request's payload and construct Bytes instance. /// Loads request's payload and construct Bytes instance.
/// ///
/// [**PayloadConfig**](dev/struct.PayloadConfig.html) allows to configure extraction /// [**PayloadConfig**](dev/struct.PayloadConfig.html) allows to configure
/// process. /// extraction process.
/// ///
/// ## Example /// ## Example
/// ///
@ -313,11 +323,10 @@ impl Default for FormConfig {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
impl<S: 'static> FromRequest<S> for Bytes impl<S: 'static> FromRequest<S> for Bytes {
{
type Config = PayloadConfig; type Config = PayloadConfig;
type Result = Either<FutureResult<Self, Error>, type Result =
Box<Future<Item=Self, Error=Error>>>; Either<FutureResult<Self, Error>, Box<Future<Item = Self, Error = Error>>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -326,9 +335,11 @@ impl<S: 'static> FromRequest<S> for Bytes
return Either::A(result(Err(e))); return Either::A(result(Err(e)));
} }
Either::B(Box::new(MessageBody::new(req.clone()) Either::B(Box::new(
.limit(cfg.limit) MessageBody::new(req.clone())
.from_err())) .limit(cfg.limit)
.from_err(),
))
} }
} }
@ -351,11 +362,10 @@ impl<S: 'static> FromRequest<S> for Bytes
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
impl<S: 'static> FromRequest<S> for String impl<S: 'static> FromRequest<S> for String {
{
type Config = PayloadConfig; type Config = PayloadConfig;
type Result = Either<FutureResult<String, Error>, type Result =
Box<Future<Item=String, Error=Error>>>; Either<FutureResult<String, Error>, Box<Future<Item = String, Error = Error>>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -366,8 +376,11 @@ impl<S: 'static> FromRequest<S> for String
// check charset // check charset
let encoding = match req.encoding() { let encoding = match req.encoding() {
Err(_) => return Either::A( Err(_) => {
result(Err(ErrorBadRequest("Unknown request charset")))), return Either::A(result(Err(ErrorBadRequest(
"Unknown request charset",
))))
}
Ok(encoding) => encoding, Ok(encoding) => encoding,
}; };
@ -379,13 +392,15 @@ impl<S: 'static> FromRequest<S> for String
let enc: *const Encoding = encoding as *const Encoding; let enc: *const Encoding = encoding as *const Encoding;
if enc == UTF_8 { if enc == UTF_8 {
Ok(str::from_utf8(body.as_ref()) Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))? .map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned()) .to_owned())
} else { } else {
Ok(encoding.decode(&body, DecoderTrap::Strict) Ok(encoding
.map_err(|_| ErrorBadRequest("Can not decode body"))?) .decode(&body, DecoderTrap::Strict)
.map_err(|_| ErrorBadRequest("Can not decode body"))?)
} }
}))) }),
))
} }
} }
@ -396,14 +411,14 @@ pub struct PayloadConfig {
} }
impl PayloadConfig { impl PayloadConfig {
/// Change max size of payload. By default max size is 256Kb /// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self { pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit; self.limit = limit;
self self
} }
/// Set required mime-type of the request. By default mime type is not enforced. /// Set required mime-type of the request. By default mime type is not
/// enforced.
pub fn mimetype(&mut self, mt: Mime) -> &mut Self { pub fn mimetype(&mut self, mt: Mime) -> &mut Self {
self.mimetype = Some(mt); self.mimetype = Some(mt);
self self
@ -417,13 +432,13 @@ impl PayloadConfig {
if mt != req_mt { if mt != req_mt {
return Err(ErrorBadRequest("Unexpected Content-Type")); return Err(ErrorBadRequest("Unexpected Content-Type"));
} }
}, }
Ok(None) => { Ok(None) => {
return Err(ErrorBadRequest("Content-Type is expected")); return Err(ErrorBadRequest("Content-Type is expected"));
}, }
Err(err) => { Err(err) => {
return Err(err.into()); return Err(err.into());
}, }
} }
} }
Ok(()) Ok(())
@ -432,21 +447,24 @@ impl PayloadConfig {
impl Default for PayloadConfig { impl Default for PayloadConfig {
fn default() -> Self { fn default() -> Self {
PayloadConfig{limit: 262_144, mimetype: None} PayloadConfig {
limit: 262_144,
mimetype: None,
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use mime;
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Future}; use futures::{Async, Future};
use http::header; use http::header;
use router::{Router, Resource}; use mime;
use resource::ResourceHandler; use resource::ResourceHandler;
use test::TestRequest; use router::{Resource, Router};
use server::ServerSettings; use server::ServerSettings;
use test::TestRequest;
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
struct Info { struct Info {
@ -457,12 +475,13 @@ mod tests {
fn test_bytes() { fn test_bytes() {
let cfg = PayloadConfig::default(); let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
match Bytes::from_request(&req, &cfg).poll().unwrap() { match Bytes::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s, Bytes::from_static(b"hello=world")); assert_eq!(s, Bytes::from_static(b"hello=world"));
}, }
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -471,12 +490,13 @@ mod tests {
fn test_string() { fn test_string() {
let cfg = PayloadConfig::default(); let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
match String::from_request(&req, &cfg).poll().unwrap() { match String::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s, "hello=world"); assert_eq!(s, "hello=world");
}, }
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -484,17 +504,19 @@ mod tests {
#[test] #[test]
fn test_form() { fn test_form() {
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded") header::CONTENT_TYPE,
.header(header::CONTENT_LENGTH, "11") "application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "11")
.finish(); .finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let mut cfg = FormConfig::default(); let mut cfg = FormConfig::default();
cfg.limit(4096); cfg.limit(4096);
match Form::<Info>::from_request(&req, &cfg).poll().unwrap() { match Form::<Info>::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.hello, "world"); assert_eq!(s.hello, "world");
}, }
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -507,10 +529,13 @@ mod tests {
assert!(cfg.check_mimetype(&req).is_err()); assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header( let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded").finish(); header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
).finish();
assert!(cfg.check_mimetype(&req).is_err()); assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish(); let req =
TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish();
assert!(cfg.check_mimetype(&req).is_ok()); assert!(cfg.check_mimetype(&req).is_ok());
} }
@ -538,30 +563,39 @@ mod tests {
let mut resource = ResourceHandler::<()>::default(); let mut resource = ResourceHandler::<()>::default();
resource.name("index"); resource.name("index");
let mut routes = Vec::new(); let mut routes = Vec::new();
routes.push((Resource::new("index", "/{key}/{value}/"), Some(resource))); routes.push((
Resource::new("index", "/{key}/{value}/"),
Some(resource),
));
let (router, _) = Router::new("", ServerSettings::default(), routes); let (router, _) = Router::new("", ServerSettings::default(), routes);
assert!(router.recognize(&mut req).is_some()); assert!(router.recognize(&mut req).is_some());
match Path::<MyStruct>::from_request(&req, &()).poll().unwrap() { match Path::<MyStruct>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.key, "name"); assert_eq!(s.key, "name");
assert_eq!(s.value, "user1"); assert_eq!(s.value, "user1");
}, }
_ => unreachable!(), _ => unreachable!(),
} }
match Path::<(String, String)>::from_request(&req, &()).poll().unwrap() { match Path::<(String, String)>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.0, "name"); assert_eq!(s.0, "name");
assert_eq!(s.1, "user1"); assert_eq!(s.1, "user1");
}, }
_ => unreachable!(), _ => unreachable!(),
} }
match Query::<Id>::from_request(&req, &()).poll().unwrap() { match Query::<Id>::from_request(&req, &()).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.id, "test"); assert_eq!(s.id, "test");
}, }
_ => unreachable!(), _ => unreachable!(),
} }
@ -572,22 +606,31 @@ mod tests {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.as_ref().key, "name"); assert_eq!(s.as_ref().key, "name");
assert_eq!(s.value, 32); assert_eq!(s.value, 32);
}, }
_ => unreachable!(), _ => unreachable!(),
} }
match Path::<(String, u8)>::from_request(&req, &()).poll().unwrap() { match Path::<(String, u8)>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.0, "name"); assert_eq!(s.0, "name");
assert_eq!(s.1, 32); assert_eq!(s.1, 32);
}, }
_ => unreachable!(), _ => unreachable!(),
} }
match Path::<Vec<String>>::from_request(&req, &()).poll().unwrap() { match Path::<Vec<String>>::from_request(&req, &())
.poll()
.unwrap()
{
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.into_inner(), vec!["name".to_owned(), "32".to_owned()]); assert_eq!(
}, s.into_inner(),
vec!["name".to_owned(), "32".to_owned()]
);
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -606,7 +649,7 @@ mod tests {
match Path::<i8>::from_request(&req, &()).poll().unwrap() { match Path::<i8>::from_request(&req, &()).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s.into_inner(), 32); assert_eq!(s.into_inner(), 32);
}, }
_ => unreachable!(), _ => unreachable!(),
} }
} }

318
src/fs.rs
View File

@ -1,30 +1,30 @@
//! Static files support //! Static files support
use std::{io, cmp};
use std::io::{Read, Seek};
use std::fmt::Write; use std::fmt::Write;
use std::fs::{File, DirEntry, Metadata}; use std::fs::{DirEntry, File, Metadata};
use std::path::{Path, PathBuf}; use std::io::{Read, Seek};
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::time::{SystemTime, UNIX_EPOCH}; use std::path::{Path, PathBuf};
use std::sync::Mutex; use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use std::{cmp, io};
#[cfg(unix)] #[cfg(unix)]
use std::os::unix::fs::MetadataExt; use std::os::unix::fs::MetadataExt;
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::{Async, Poll, Future, Stream}; use futures::{Async, Future, Poll, Stream};
use futures_cpupool::{CpuPool, CpuFuture}; use futures_cpupool::{CpuFuture, CpuPool};
use mime_guess::get_mime_type; use mime_guess::get_mime_type;
use percent_encoding::percent_decode; use percent_encoding::percent_decode;
use header;
use error::Error; use error::Error;
use param::FromParam; use handler::{Handler, Reply, Responder, RouteHandler, WrapHandler};
use handler::{Handler, RouteHandler, WrapHandler, Responder, Reply}; use header;
use http::{Method, StatusCode}; use http::{Method, StatusCode};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use param::FromParam;
/// A file with an associated name; responds with the Content-Type based on the /// A file with an associated name; responds with the Content-Type based on the
/// file extension. /// file extension.
@ -55,9 +55,15 @@ impl NamedFile {
let path = path.as_ref().to_path_buf(); let path = path.as_ref().to_path_buf();
let modified = md.modified().ok(); let modified = md.modified().ok();
let cpu_pool = None; let cpu_pool = None;
Ok(NamedFile{path, file, md, modified, cpu_pool, Ok(NamedFile {
only_get: false, path,
status_code: StatusCode::OK}) file,
md,
modified,
cpu_pool,
only_get: false,
status_code: StatusCode::OK,
})
} }
/// Allow only GET and HEAD methods /// Allow only GET and HEAD methods
@ -110,17 +116,25 @@ impl NamedFile {
self.modified.as_ref().map(|mtime| { self.modified.as_ref().map(|mtime| {
let ino = { let ino = {
#[cfg(unix)] #[cfg(unix)]
{ self.md.ino() } {
self.md.ino()
}
#[cfg(not(unix))] #[cfg(not(unix))]
{ 0 } {
0
}
}; };
let dur = mtime.duration_since(UNIX_EPOCH) let dur = mtime
.duration_since(UNIX_EPOCH)
.expect("modification time must be after epoch"); .expect("modification time must be after epoch");
header::EntityTag::strong( header::EntityTag::strong(format!(
format!("{:x}:{:x}:{:x}:{:x}", "{:x}:{:x}:{:x}:{:x}",
ino, self.md.len(), dur.as_secs(), ino,
dur.subsec_nanos())) self.md.len(),
dur.as_secs(),
dur.subsec_nanos()
))
}) })
} }
@ -178,7 +192,6 @@ fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
} }
} }
impl Responder for NamedFile { impl Responder for NamedFile {
type Item = HttpResponse; type Item = HttpResponse;
type Error = io::Error; type Error = io::Error;
@ -187,23 +200,27 @@ impl Responder for NamedFile {
if self.status_code != StatusCode::OK { if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code); let mut resp = HttpResponse::build(self.status_code);
resp.if_some(self.path().extension(), |ext, resp| { resp.if_some(self.path().extension(), |ext, resp| {
resp.set(header::ContentType(get_mime_type(&ext.to_string_lossy()))); resp.set(header::ContentType(get_mime_type(
&ext.to_string_lossy(),
)));
}); });
let reader = ChunkedReadFile { let reader = ChunkedReadFile {
size: self.md.len(), size: self.md.len(),
offset: 0, offset: 0,
cpu_pool: self.cpu_pool.unwrap_or_else(|| req.cpu_pool().clone()), cpu_pool: self.cpu_pool
.unwrap_or_else(|| req.cpu_pool().clone()),
file: Some(self.file), file: Some(self.file),
fut: None, fut: None,
}; };
return Ok(resp.streaming(reader)) return Ok(resp.streaming(reader));
} }
if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD { if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD
{
return Ok(HttpResponse::MethodNotAllowed() return Ok(HttpResponse::MethodNotAllowed()
.header(header::CONTENT_TYPE, "text/plain") .header(header::CONTENT_TYPE, "text/plain")
.header(header::ALLOW, "GET, HEAD") .header(header::ALLOW, "GET, HEAD")
.body("This resource only supports GET and HEAD.")) .body("This resource only supports GET and HEAD."));
} }
let etag = self.etag(); let etag = self.etag();
@ -233,17 +250,21 @@ impl Responder for NamedFile {
let mut resp = HttpResponse::build(self.status_code); let mut resp = HttpResponse::build(self.status_code);
resp resp.if_some(self.path().extension(), |ext, resp| {
.if_some(self.path().extension(), |ext, resp| { resp.set(header::ContentType(get_mime_type(
resp.set(header::ContentType(get_mime_type(&ext.to_string_lossy()))); &ext.to_string_lossy(),
)));
}).if_some(last_modified, |lm, resp| {
resp.set(header::LastModified(lm));
}) })
.if_some(last_modified, |lm, resp| {resp.set(header::LastModified(lm));}) .if_some(etag, |etag, resp| {
.if_some(etag, |etag, resp| {resp.set(header::ETag(etag));}); resp.set(header::ETag(etag));
});
if precondition_failed { if precondition_failed {
return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish()) return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish());
} else if not_modified { } else if not_modified {
return Ok(resp.status(StatusCode::NOT_MODIFIED).finish()) return Ok(resp.status(StatusCode::NOT_MODIFIED).finish());
} }
if *req.method() == Method::HEAD { if *req.method() == Method::HEAD {
@ -252,7 +273,8 @@ impl Responder for NamedFile {
let reader = ChunkedReadFile { let reader = ChunkedReadFile {
size: self.md.len(), size: self.md.len(),
offset: 0, offset: 0,
cpu_pool: self.cpu_pool.unwrap_or_else(|| req.cpu_pool().clone()), cpu_pool: self.cpu_pool
.unwrap_or_else(|| req.cpu_pool().clone()),
file: Some(self.file), file: Some(self.file),
fut: None, fut: None,
}; };
@ -273,7 +295,7 @@ pub struct ChunkedReadFile {
impl Stream for ChunkedReadFile { impl Stream for ChunkedReadFile {
type Item = Bytes; type Item = Bytes;
type Error= Error; type Error = Error;
fn poll(&mut self) -> Poll<Option<Bytes>, Error> { fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
if self.fut.is_some() { if self.fut.is_some() {
@ -283,7 +305,7 @@ impl Stream for ChunkedReadFile {
self.file = Some(file); self.file = Some(file);
self.offset += bytes.len() as u64; self.offset += bytes.len() as u64;
Ok(Async::Ready(Some(bytes))) Ok(Async::Ready(Some(bytes)))
}, }
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
}; };
} }
@ -299,11 +321,11 @@ impl Stream for ChunkedReadFile {
let max_bytes = cmp::min(size.saturating_sub(offset), 65_536) as usize; let max_bytes = cmp::min(size.saturating_sub(offset), 65_536) as usize;
let mut buf = BytesMut::with_capacity(max_bytes); let mut buf = BytesMut::with_capacity(max_bytes);
file.seek(io::SeekFrom::Start(offset))?; file.seek(io::SeekFrom::Start(offset))?;
let nbytes = file.read(unsafe{buf.bytes_mut()})?; let nbytes = file.read(unsafe { buf.bytes_mut() })?;
if nbytes == 0 { if nbytes == 0 {
return Err(io::ErrorKind::UnexpectedEof.into()) return Err(io::ErrorKind::UnexpectedEof.into());
} }
unsafe{buf.advance_mut(nbytes)}; unsafe { buf.advance_mut(nbytes) };
Ok((file, buf.freeze())) Ok((file, buf.freeze()))
})); }));
self.poll() self.poll()
@ -313,9 +335,9 @@ impl Stream for ChunkedReadFile {
/// A directory; responds with the generated directory listing. /// A directory; responds with the generated directory listing.
#[derive(Debug)] #[derive(Debug)]
pub struct Directory{ pub struct Directory {
base: PathBuf, base: PathBuf,
path: PathBuf path: PathBuf,
} }
impl Directory { impl Directory {
@ -327,12 +349,12 @@ impl Directory {
if let Ok(ref entry) = *entry { if let Ok(ref entry) = *entry {
if let Some(name) = entry.file_name().to_str() { if let Some(name) = entry.file_name().to_str() {
if name.starts_with('.') { if name.starts_with('.') {
return false return false;
} }
} }
if let Ok(ref md) = entry.metadata() { if let Ok(ref md) = entry.metadata() {
let ft = md.file_type(); let ft = md.file_type();
return ft.is_dir() || ft.is_file() || ft.is_symlink() return ft.is_dir() || ft.is_file() || ft.is_symlink();
} }
} }
false false
@ -353,7 +375,7 @@ impl Responder for Directory {
let entry = entry.unwrap(); let entry = entry.unwrap();
let p = match entry.path().strip_prefix(&self.path) { let p = match entry.path().strip_prefix(&self.path) {
Ok(p) => base.join(p), Ok(p) => base.join(p),
Err(_) => continue Err(_) => continue,
}; };
// show file url as relative to static path // show file url as relative to static path
let file_url = format!("{}", p.to_string_lossy()); let file_url = format!("{}", p.to_string_lossy());
@ -361,27 +383,38 @@ impl Responder for Directory {
// if file is a directory, add '/' to the end of the name // if file is a directory, add '/' to the end of the name
if let Ok(metadata) = entry.metadata() { if let Ok(metadata) = entry.metadata() {
if metadata.is_dir() { if metadata.is_dir() {
let _ = write!(body, "<li><a href=\"{}\">{}/</a></li>", let _ = write!(
file_url, entry.file_name().to_string_lossy()); body,
"<li><a href=\"{}\">{}/</a></li>",
file_url,
entry.file_name().to_string_lossy()
);
} else { } else {
let _ = write!(body, "<li><a href=\"{}\">{}</a></li>", let _ = write!(
file_url, entry.file_name().to_string_lossy()); body,
"<li><a href=\"{}\">{}</a></li>",
file_url,
entry.file_name().to_string_lossy()
);
} }
} else { } else {
continue continue;
} }
} }
} }
let html = format!("<html>\ let html = format!(
<head><title>{}</title></head>\ "<html>\
<body><h1>{}</h1>\ <head><title>{}</title></head>\
<ul>\ <body><h1>{}</h1>\
{}\ <ul>\
</ul></body>\n</html>", index_of, index_of, body); {}\
</ul></body>\n</html>",
index_of, index_of, body
);
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
.content_type("text/html; charset=utf-8") .content_type("text/html; charset=utf-8")
.body(html)) .body(html))
} }
} }
@ -411,12 +444,11 @@ pub struct StaticFiles<S> {
_follow_symlinks: bool, _follow_symlinks: bool,
} }
lazy_static!{ lazy_static! {
static ref DEFAULT_CPUPOOL: Mutex<CpuPool> = Mutex::new(CpuPool::new(20)); static ref DEFAULT_CPUPOOL: Mutex<CpuPool> = Mutex::new(CpuPool::new(20));
} }
impl<S: 'static> StaticFiles<S> { impl<S: 'static> StaticFiles<S> {
/// Create new `StaticFiles` instance for specified base directory. /// Create new `StaticFiles` instance for specified base directory.
pub fn new<T: Into<PathBuf>>(dir: T) -> StaticFiles<S> { pub fn new<T: Into<PathBuf>>(dir: T) -> StaticFiles<S> {
let dir = dir.into(); let dir = dir.into();
@ -429,7 +461,7 @@ impl<S: 'static> StaticFiles<S> {
warn!("Is not directory `{:?}`", dir); warn!("Is not directory `{:?}`", dir);
(dir, false) (dir, false)
} }
}, }
Err(err) => { Err(err) => {
warn!("Static files directory `{:?}` error: {}", dir, err); warn!("Static files directory `{:?}` error: {}", dir, err);
(dir, false) (dir, false)
@ -437,9 +469,7 @@ impl<S: 'static> StaticFiles<S> {
}; };
// use default CpuPool // use default CpuPool
let pool = { let pool = { DEFAULT_CPUPOOL.lock().unwrap().clone() };
DEFAULT_CPUPOOL.lock().unwrap().clone()
};
StaticFiles { StaticFiles {
directory: dir, directory: dir,
@ -447,8 +477,9 @@ impl<S: 'static> StaticFiles<S> {
index: None, index: None,
show_index: false, show_index: false,
cpu_pool: pool, cpu_pool: pool,
default: Box::new(WrapHandler::new( default: Box::new(WrapHandler::new(|_| {
|_| HttpResponse::new(StatusCode::NOT_FOUND))), HttpResponse::new(StatusCode::NOT_FOUND)
})),
_chunk_size: 0, _chunk_size: 0,
_follow_symlinks: false, _follow_symlinks: false,
} }
@ -485,12 +516,13 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
if !self.accessible { if !self.accessible {
Ok(self.default.handle(req)) Ok(self.default.handle(req))
} else { } else {
let relpath = match req.match_info().get("tail").map( let relpath = match req.match_info()
|tail| percent_decode(tail.as_bytes()).decode_utf8().unwrap()) .get("tail")
.map(|tail| percent_decode(tail.as_bytes()).decode_utf8().unwrap())
.map(|tail| PathBuf::from_param(tail.as_ref())) .map(|tail| PathBuf::from_param(tail.as_ref()))
{ {
Some(Ok(path)) => path, Some(Ok(path)) => path,
_ => return Ok(self.default.handle(req)) _ => return Ok(self.default.handle(req)),
}; };
// full filepath // full filepath
@ -499,7 +531,8 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
if path.is_dir() { if path.is_dir() {
if let Some(ref redir_index) = self.index { if let Some(ref redir_index) = self.index {
// TODO: Don't redirect, just return the index content. // TODO: Don't redirect, just return the index content.
// TODO: It'd be nice if there were a good usable URL manipulation library // TODO: It'd be nice if there were a good usable URL manipulation
// library
let mut new_path: String = req.path().to_owned(); let mut new_path: String = req.path().to_owned();
for el in relpath.iter() { for el in relpath.iter() {
new_path.push_str(&el.to_string_lossy()); new_path.push_str(&el.to_string_lossy());
@ -516,14 +549,15 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
} else if self.show_index { } else if self.show_index {
Directory::new(self.directory.clone(), path) Directory::new(self.directory.clone(), path)
.respond_to(req.drop_state())? .respond_to(req.drop_state())?
.respond_to(req.drop_state()) .respond_to(req.drop_state())
} else { } else {
Ok(self.default.handle(req)) Ok(self.default.handle(req))
} }
} else { } else {
NamedFile::open(path)?.set_cpu_pool(self.cpu_pool.clone()) NamedFile::open(path)?
.set_cpu_pool(self.cpu_pool.clone())
.respond_to(req.drop_state())? .respond_to(req.drop_state())?
.respond_to(req.drop_state()) .respond_to(req.drop_state())
} }
} }
} }
@ -533,33 +567,49 @@ impl<S: 'static> Handler<S> for StaticFiles<S> {
mod tests { mod tests {
use super::*; use super::*;
use application::App; use application::App;
use test::{self, TestRequest};
use http::{header, Method, StatusCode}; use http::{header, Method, StatusCode};
use test::{self, TestRequest};
#[test] #[test]
fn test_named_file() { fn test_named_file() {
assert!(NamedFile::open("test--").is_err()); assert!(NamedFile::open("test--").is_err());
let mut file = NamedFile::open("Cargo.toml").unwrap() let mut file = NamedFile::open("Cargo.toml")
.unwrap()
.set_cpu_pool(CpuPool::new(1)); .set_cpu_pool(CpuPool::new(1));
{ file.file(); {
let _f: &File = &file; } file.file();
{ let _f: &mut File = &mut file; } let _f: &File = &file;
}
{
let _f: &mut File = &mut file;
}
let resp = file.respond_to(HttpRequest::default()).unwrap(); let resp = file.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml") assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/x-toml"
)
} }
#[test] #[test]
fn test_named_file_status_code() { fn test_named_file_status_code() {
let mut file = NamedFile::open("Cargo.toml").unwrap() let mut file = NamedFile::open("Cargo.toml")
.unwrap()
.set_status_code(StatusCode::NOT_FOUND) .set_status_code(StatusCode::NOT_FOUND)
.set_cpu_pool(CpuPool::new(1)); .set_cpu_pool(CpuPool::new(1));
{ file.file(); {
let _f: &File = &file; } file.file();
{ let _f: &mut File = &mut file; } let _f: &File = &file;
}
{
let _f: &mut File = &mut file;
}
let resp = file.respond_to(HttpRequest::default()).unwrap(); let resp = file.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml"); assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/x-toml"
);
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
} }
@ -584,13 +634,17 @@ mod tests {
fn test_static_files() { fn test_static_files() {
let mut st = StaticFiles::new(".").show_files_listing(); let mut st = StaticFiles::new(".").show_files_listing();
st.accessible = false; st.accessible = false;
let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(HttpRequest::default())
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
st.accessible = true; st.accessible = true;
st.show_index = false; st.show_index = false;
let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(HttpRequest::default())
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
@ -598,9 +652,14 @@ mod tests {
req.match_info_mut().add("tail", ""); req.match_info_mut().add("tail", "");
st.show_index = true; st.show_index = true;
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/html; charset=utf-8"); assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"text/html; charset=utf-8"
);
assert!(resp.body().is_binary()); assert!(resp.body().is_binary());
assert!(format!("{:?}", resp.body()).contains("README.md")); assert!(format!("{:?}", resp.body()).contains("README.md"));
} }
@ -611,18 +670,28 @@ mod tests {
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "guide"); req.match_info_mut().add("tail", "guide");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND); assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html"); assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/guide/index.html"
);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "guide/"); req.match_info_mut().add("tail", "guide/");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND); assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html"); assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/guide/index.html"
);
} }
#[test] #[test]
@ -631,58 +700,87 @@ mod tests {
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.match_info_mut().add("tail", "tools/wsload"); req.match_info_mut().add("tail", "tools/wsload");
let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); let resp = st.handle(req)
.respond_to(HttpRequest::default())
.unwrap();
let resp = resp.as_response().expect("HTTP Response"); let resp = resp.as_response().expect("HTTP Response");
assert_eq!(resp.status(), StatusCode::FOUND); assert_eq!(resp.status(), StatusCode::FOUND);
assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/tools/wsload/Cargo.toml"); assert_eq!(
resp.headers().get(header::LOCATION).unwrap(),
"/tools/wsload/Cargo.toml"
);
} }
#[test] #[test]
fn integration_redirect_to_index_with_prefix() { fn integration_redirect_to_index_with_prefix() {
let mut srv = test::TestServer::with_factory( let mut srv = test::TestServer::with_factory(|| {
|| App::new() App::new()
.prefix("public") .prefix("public")
.handler("/", StaticFiles::new(".").index_file("Cargo.toml"))); .handler("/", StaticFiles::new(".").index_file("Cargo.toml"))
});
let request = srv.get().uri(srv.url("/public")).finish().unwrap(); let request = srv.get().uri(srv.url("/public")).finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND); assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap(); let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/public/Cargo.toml"); assert_eq!(loc, "/public/Cargo.toml");
let request = srv.get().uri(srv.url("/public/")).finish().unwrap(); let request = srv.get().uri(srv.url("/public/")).finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND); assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap(); let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/public/Cargo.toml"); assert_eq!(loc, "/public/Cargo.toml");
} }
#[test] #[test]
fn integration_redirect_to_index() { fn integration_redirect_to_index() {
let mut srv = test::TestServer::with_factory( let mut srv = test::TestServer::with_factory(|| {
|| App::new() App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
.handler("test", StaticFiles::new(".").index_file("Cargo.toml"))); });
let request = srv.get().uri(srv.url("/test")).finish().unwrap(); let request = srv.get().uri(srv.url("/test")).finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND); assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap(); let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/test/Cargo.toml"); assert_eq!(loc, "/test/Cargo.toml");
let request = srv.get().uri(srv.url("/test/")).finish().unwrap(); let request = srv.get().uri(srv.url("/test/")).finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::FOUND); assert_eq!(response.status(), StatusCode::FOUND);
let loc = response.headers().get(header::LOCATION).unwrap().to_str().unwrap(); let loc = response
.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap();
assert_eq!(loc, "/test/Cargo.toml"); assert_eq!(loc, "/test/Cargo.toml");
} }
#[test] #[test]
fn integration_percent_encoded() { fn integration_percent_encoded() {
let mut srv = test::TestServer::with_factory( let mut srv = test::TestServer::with_factory(|| {
|| App::new() App::new().handler("test", StaticFiles::new(".").index_file("Cargo.toml"))
.handler("test", StaticFiles::new(".").index_file("Cargo.toml"))); });
let request = srv.get().uri(srv.url("/test/%43argo.toml")).finish().unwrap(); let request = srv.get()
.uri(srv.url("/test/%43argo.toml"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }

View File

@ -1,7 +1,7 @@
use std::ops::Deref;
use std::marker::PhantomData;
use futures::Poll; use futures::Poll;
use futures::future::{Future, FutureResult, ok, err}; use futures::future::{err, ok, Future, FutureResult};
use std::marker::PhantomData;
use std::ops::Deref;
use error::Error; use error::Error;
use httprequest::HttpRequest; use httprequest::HttpRequest;
@ -10,7 +10,6 @@ use httpresponse::HttpResponse;
/// Trait defines object that could be registered as route handler /// Trait defines object that could be registered as route handler
#[allow(unused_variables)] #[allow(unused_variables)]
pub trait Handler<S>: 'static { pub trait Handler<S>: 'static {
/// The type of value that handler will return. /// The type of value that handler will return.
type Result: Responder; type Result: Responder;
@ -35,13 +34,15 @@ pub trait Responder {
/// Trait implemented by types that can be extracted from request. /// Trait implemented by types that can be extracted from request.
/// ///
/// Types that implement this trait can be used with `Route::with()` method. /// Types that implement this trait can be used with `Route::with()` method.
pub trait FromRequest<S>: Sized where S: 'static pub trait FromRequest<S>: Sized
where
S: 'static,
{ {
/// Configuration for conversion process /// Configuration for conversion process
type Config: Default; type Config: Default;
/// Future that resolves to a Self /// Future that resolves to a Self
type Result: Future<Item=Self, Error=Error>; type Result: Future<Item = Self, Error = Error>;
/// Convert request to a Self /// Convert request to a Self
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result; fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result;
@ -83,7 +84,9 @@ pub enum Either<A, B> {
} }
impl<A, B> Responder for Either<A, B> impl<A, B> Responder for Either<A, B>
where A: Responder, B: Responder where
A: Responder,
B: Responder,
{ {
type Item = Reply; type Item = Reply;
type Error = Error; type Error = Error;
@ -103,8 +106,9 @@ impl<A, B> Responder for Either<A, B>
} }
impl<A, B, I, E> Future for Either<A, B> impl<A, B, I, E> Future for Either<A, B>
where A: Future<Item=I, Error=E>, where
B: Future<Item=I, Error=E>, A: Future<Item = I, Error = E>,
B: Future<Item = I, Error = E>,
{ {
type Item = I; type Item = I;
type Error = E; type Error = E;
@ -146,23 +150,25 @@ impl<A, B, I, E> Future for Either<A, B>
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
pub trait AsyncResponder<I, E>: Sized { pub trait AsyncResponder<I, E>: Sized {
fn responder(self) -> Box<Future<Item=I, Error=E>>; fn responder(self) -> Box<Future<Item = I, Error = E>>;
} }
impl<F, I, E> AsyncResponder<I, E> for F impl<F, I, E> AsyncResponder<I, E> for F
where F: Future<Item=I, Error=E> + 'static, where
I: Responder + 'static, F: Future<Item = I, Error = E> + 'static,
E: Into<Error> + 'static, I: Responder + 'static,
E: Into<Error> + 'static,
{ {
fn responder(self) -> Box<Future<Item=I, Error=E>> { fn responder(self) -> Box<Future<Item = I, Error = E>> {
Box::new(self) Box::new(self)
} }
} }
/// Handler<S> for Fn() /// Handler<S> for Fn()
impl<F, R, S> Handler<S> for F impl<F, R, S> Handler<S> for F
where F: Fn(HttpRequest<S>) -> R + 'static, where
R: Responder + 'static F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{ {
type Result = R; type Result = R;
@ -176,15 +182,15 @@ pub struct Reply(ReplyItem);
pub(crate) enum ReplyItem { pub(crate) enum ReplyItem {
Message(HttpResponse), Message(HttpResponse),
Future(Box<Future<Item=HttpResponse, Error=Error>>), Future(Box<Future<Item = HttpResponse, Error = Error>>),
} }
impl Reply { impl Reply {
/// Create async response /// Create async response
#[inline] #[inline]
pub fn async<F>(fut: F) -> Reply pub fn async<F>(fut: F) -> Reply
where F: Future<Item=HttpResponse, Error=Error> + 'static where
F: Future<Item = HttpResponse, Error = Error> + 'static,
{ {
Reply(ReplyItem::Future(Box::new(fut))) Reply(ReplyItem::Future(Box::new(fut)))
} }
@ -229,15 +235,13 @@ impl Responder for HttpResponse {
} }
impl From<HttpResponse> for Reply { impl From<HttpResponse> for Reply {
#[inline] #[inline]
fn from(resp: HttpResponse) -> Reply { fn from(resp: HttpResponse) -> Reply {
Reply(ReplyItem::Message(resp)) Reply(ReplyItem::Message(resp))
} }
} }
impl<T: Responder, E: Into<Error>> Responder for Result<T, E> impl<T: Responder, E: Into<Error>> Responder for Result<T, E> {
{
type Item = <T as Responder>::Item; type Item = <T as Responder>::Item;
type Error = Error; type Error = Error;
@ -272,19 +276,20 @@ impl<E: Into<Error>> From<Result<HttpResponse, E>> for Reply {
} }
} }
impl From<Box<Future<Item=HttpResponse, Error=Error>>> for Reply { impl From<Box<Future<Item = HttpResponse, Error = Error>>> for Reply {
#[inline] #[inline]
fn from(fut: Box<Future<Item=HttpResponse, Error=Error>>) -> Reply { fn from(fut: Box<Future<Item = HttpResponse, Error = Error>>) -> Reply {
Reply(ReplyItem::Future(fut)) Reply(ReplyItem::Future(fut))
} }
} }
/// Convenience type alias /// Convenience type alias
pub type FutureResponse<I, E=Error> = Box<Future<Item=I, Error=E>>; pub type FutureResponse<I, E = Error> = Box<Future<Item = I, Error = E>>;
impl<I, E> Responder for Box<Future<Item=I, Error=E>> impl<I, E> Responder for Box<Future<Item = I, Error = E>>
where I: Responder + 'static, where
E: Into<Error> + 'static I: Responder + 'static,
E: Into<Error> + 'static,
{ {
type Item = Reply; type Item = Reply;
type Error = Error; type Error = Error;
@ -292,14 +297,12 @@ impl<I, E> Responder for Box<Future<Item=I, Error=E>>
#[inline] #[inline]
fn respond_to(self, req: HttpRequest) -> Result<Reply, Error> { fn respond_to(self, req: HttpRequest) -> Result<Reply, Error> {
let fut = self.map_err(|e| e.into()) let fut = self.map_err(|e| e.into())
.then(move |r| { .then(move |r| match r.respond_to(req) {
match r.respond_to(req) { Ok(reply) => match reply.into().0 {
Ok(reply) => match reply.into().0 { ReplyItem::Message(resp) => ok(resp),
ReplyItem::Message(resp) => ok(resp), _ => panic!("Nested async replies are not supported"),
_ => panic!("Nested async replies are not supported"), },
}, Err(e) => err(e),
Err(e) => err(e),
}
}); });
Ok(Reply::async(fut)) Ok(Reply::async(fut))
} }
@ -311,30 +314,35 @@ pub(crate) trait RouteHandler<S>: 'static {
} }
/// Route handler wrapper for Handler /// Route handler wrapper for Handler
pub(crate) pub(crate) struct WrapHandler<S, H, R>
struct WrapHandler<S, H, R> where
where H: Handler<S, Result=R>, H: Handler<S, Result = R>,
R: Responder, R: Responder,
S: 'static, S: 'static,
{ {
h: H, h: H,
s: PhantomData<S>, s: PhantomData<S>,
} }
impl<S, H, R> WrapHandler<S, H, R> impl<S, H, R> WrapHandler<S, H, R>
where H: Handler<S, Result=R>, where
R: Responder, H: Handler<S, Result = R>,
S: 'static, R: Responder,
S: 'static,
{ {
pub fn new(h: H) -> Self { pub fn new(h: H) -> Self {
WrapHandler{h, s: PhantomData} WrapHandler {
h,
s: PhantomData,
}
} }
} }
impl<S, H, R> RouteHandler<S> for WrapHandler<S, H, R> impl<S, H, R> RouteHandler<S> for WrapHandler<S, H, R>
where H: Handler<S, Result=R>, where
R: Responder + 'static, H: Handler<S, Result = R>,
S: 'static, R: Responder + 'static,
S: 'static,
{ {
fn handle(&mut self, req: HttpRequest<S>) -> Reply { fn handle(&mut self, req: HttpRequest<S>) -> Reply {
let req2 = req.drop_state(); let req2 = req.drop_state();
@ -346,50 +354,53 @@ impl<S, H, R> RouteHandler<S> for WrapHandler<S, H, R>
} }
/// Async route handler /// Async route handler
pub(crate) pub(crate) struct AsyncHandler<S, H, F, R, E>
struct AsyncHandler<S, H, F, R, E> where
where H: Fn(HttpRequest<S>) -> F + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
F: Future<Item=R, Error=E> + 'static, F: Future<Item = R, Error = E> + 'static,
R: Responder + 'static, R: Responder + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
S: 'static, S: 'static,
{ {
h: Box<H>, h: Box<H>,
s: PhantomData<S>, s: PhantomData<S>,
} }
impl<S, H, F, R, E> AsyncHandler<S, H, F, R, E> impl<S, H, F, R, E> AsyncHandler<S, H, F, R, E>
where H: Fn(HttpRequest<S>) -> F + 'static, where
F: Future<Item=R, Error=E> + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
R: Responder + 'static, F: Future<Item = R, Error = E> + 'static,
E: Into<Error> + 'static, R: Responder + 'static,
S: 'static, E: Into<Error> + 'static,
S: 'static,
{ {
pub fn new(h: H) -> Self { pub fn new(h: H) -> Self {
AsyncHandler{h: Box::new(h), s: PhantomData} AsyncHandler {
h: Box::new(h),
s: PhantomData,
}
} }
} }
impl<S, H, F, R, E> RouteHandler<S> for AsyncHandler<S, H, F, R, E> impl<S, H, F, R, E> RouteHandler<S> for AsyncHandler<S, H, F, R, E>
where H: Fn(HttpRequest<S>) -> F + 'static, where
F: Future<Item=R, Error=E> + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
R: Responder + 'static, F: Future<Item = R, Error = E> + 'static,
E: Into<Error> + 'static, R: Responder + 'static,
S: 'static, E: Into<Error> + 'static,
S: 'static,
{ {
fn handle(&mut self, req: HttpRequest<S>) -> Reply { fn handle(&mut self, req: HttpRequest<S>) -> Reply {
let req2 = req.drop_state(); let req2 = req.drop_state();
let fut = (self.h)(req) let fut = (self.h)(req).map_err(|e| e.into()).then(move |r| {
.map_err(|e| e.into()) match r.respond_to(req2) {
.then(move |r| { Ok(reply) => match reply.into().0 {
match r.respond_to(req2) { ReplyItem::Message(resp) => ok(resp),
Ok(reply) => match reply.into().0 { _ => panic!("Nested async replies are not supported"),
ReplyItem::Message(resp) => ok(resp), },
_ => panic!("Nested async replies are not supported"), Err(e) => err(e),
}, }
Err(e) => err(e), });
}
});
Reply::async(fut) Reply::async(fut)
} }
} }
@ -426,7 +437,7 @@ impl<S, H, F, R, E> RouteHandler<S> for AsyncHandler<S, H, F, R, E>
/// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor /// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor
/// } /// }
/// ``` /// ```
pub struct State<S> (HttpRequest<S>); pub struct State<S>(HttpRequest<S>);
impl<S> Deref for State<S> { impl<S> Deref for State<S> {
type Target = S; type Target = S;
@ -436,8 +447,7 @@ impl<S> Deref for State<S> {
} }
} }
impl<S: 'static> FromRequest<S> for State<S> impl<S: 'static> FromRequest<S> for State<S> {
{
type Config = (); type Config = ();
type Result = FutureResult<Self, Error>; type Result = FutureResult<Self, Error>;

View File

@ -1,6 +1,6 @@
use mime::{self, Mime}; use header::{qitem, QualityItem};
use header::{QualityItem, qitem};
use http::header as http; use http::header as http;
use mime::{self, Mime};
header! { header! {
/// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2) /// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2)

View File

@ -1,4 +1,4 @@
use header::{ACCEPT_CHARSET, Charset, QualityItem}; use header::{Charset, QualityItem, ACCEPT_CHARSET};
header! { header! {
/// `Accept-Charset` header, defined in /// `Accept-Charset` header, defined in

View File

@ -1,5 +1,5 @@
use header::{QualityItem, ACCEPT_LANGUAGE};
use language_tags::LanguageTag; use language_tags::LanguageTag;
use header::{ACCEPT_LANGUAGE, QualityItem};
header! { header! {
/// `Accept-Language` header, defined in /// `Accept-Language` header, defined in

View File

@ -1,8 +1,8 @@
use header::{Header, IntoHeaderValue, Writer};
use header::{fmt_comma_delimited, from_comma_delimited};
use http::header;
use std::fmt::{self, Write}; use std::fmt::{self, Write};
use std::str::FromStr; use std::str::FromStr;
use http::header;
use header::{Header, IntoHeaderValue, Writer};
use header::{from_comma_delimited, fmt_comma_delimited};
/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) /// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2)
/// ///
@ -30,9 +30,7 @@ use header::{from_comma_delimited, fmt_comma_delimited};
/// use actix_web::http::header::{CacheControl, CacheDirective}; /// use actix_web::http::header::{CacheControl, CacheDirective};
/// ///
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// builder.set( /// builder.set(CacheControl(vec![CacheDirective::MaxAge(86400u32)]));
/// CacheControl(vec![CacheDirective::MaxAge(86400u32)])
/// );
/// ``` /// ```
/// ///
/// ```rust /// ```rust
@ -40,15 +38,12 @@ use header::{from_comma_delimited, fmt_comma_delimited};
/// use actix_web::http::header::{CacheControl, CacheDirective}; /// use actix_web::http::header::{CacheControl, CacheDirective};
/// ///
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// builder.set( /// builder.set(CacheControl(vec![
/// CacheControl(vec![ /// CacheDirective::NoCache,
/// CacheDirective::NoCache, /// CacheDirective::Private,
/// CacheDirective::Private, /// CacheDirective::MaxAge(360u32),
/// CacheDirective::MaxAge(360u32), /// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())),
/// CacheDirective::Extension("foo".to_owned(), /// ]));
/// Some("bar".to_owned())),
/// ])
/// );
/// ``` /// ```
#[derive(PartialEq, Clone, Debug)] #[derive(PartialEq, Clone, Debug)]
pub struct CacheControl(pub Vec<CacheDirective>); pub struct CacheControl(pub Vec<CacheDirective>);
@ -63,7 +58,8 @@ impl Header for CacheControl {
#[inline] #[inline]
fn parse<T>(msg: &T) -> Result<Self, ::error::ParseError> fn parse<T>(msg: &T) -> Result<Self, ::error::ParseError>
where T: ::HttpMessage where
T: ::HttpMessage,
{ {
let directives = from_comma_delimited(msg.headers().get_all(Self::name()))?; let directives = from_comma_delimited(msg.headers().get_all(Self::name()))?;
if !directives.is_empty() { if !directives.is_empty() {
@ -123,32 +119,36 @@ pub enum CacheDirective {
SMaxAge(u32), SMaxAge(u32),
/// Extension directives. Optionally include an argument. /// Extension directives. Optionally include an argument.
Extension(String, Option<String>) Extension(String, Option<String>),
} }
impl fmt::Display for CacheDirective { impl fmt::Display for CacheDirective {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::CacheDirective::*; use self::CacheDirective::*;
fmt::Display::fmt(match *self { fmt::Display::fmt(
NoCache => "no-cache", match *self {
NoStore => "no-store", NoCache => "no-cache",
NoTransform => "no-transform", NoStore => "no-store",
OnlyIfCached => "only-if-cached", NoTransform => "no-transform",
OnlyIfCached => "only-if-cached",
MaxAge(secs) => return write!(f, "max-age={}", secs), MaxAge(secs) => return write!(f, "max-age={}", secs),
MaxStale(secs) => return write!(f, "max-stale={}", secs), MaxStale(secs) => return write!(f, "max-stale={}", secs),
MinFresh(secs) => return write!(f, "min-fresh={}", secs), MinFresh(secs) => return write!(f, "min-fresh={}", secs),
MustRevalidate => "must-revalidate", MustRevalidate => "must-revalidate",
Public => "public", Public => "public",
Private => "private", Private => "private",
ProxyRevalidate => "proxy-revalidate", ProxyRevalidate => "proxy-revalidate",
SMaxAge(secs) => return write!(f, "s-maxage={}", secs), SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
Extension(ref name, None) => &name[..], Extension(ref name, None) => &name[..],
Extension(ref name, Some(ref arg)) => return write!(f, "{}={}", name, arg), Extension(ref name, Some(ref arg)) => {
return write!(f, "{}={}", name, arg)
}, f) }
},
f,
)
} }
} }
@ -167,16 +167,20 @@ impl FromStr for CacheDirective {
"proxy-revalidate" => Ok(ProxyRevalidate), "proxy-revalidate" => Ok(ProxyRevalidate),
"" => Err(None), "" => Err(None),
_ => match s.find('=') { _ => match s.find('=') {
Some(idx) if idx+1 < s.len() => match (&s[..idx], (&s[idx+1..]).trim_matches('"')) { Some(idx) if idx + 1 < s.len() => {
("max-age" , secs) => secs.parse().map(MaxAge).map_err(Some), match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) {
("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some),
("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some),
("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some),
(left, right) => Ok(Extension(left.to_owned(), Some(right.to_owned()))) ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some),
}, (left, right) => {
Ok(Extension(left.to_owned(), Some(right.to_owned())))
}
}
}
Some(_) => Err(None), Some(_) => Err(None),
None => Ok(Extension(s.to_owned(), None)) None => Ok(Extension(s.to_owned(), None)),
} },
} }
} }
} }
@ -189,38 +193,56 @@ mod tests {
#[test] #[test]
fn test_parse_multiple_headers() { fn test_parse_multiple_headers() {
let req = TestRequest::with_header( let req = TestRequest::with_header(header::CACHE_CONTROL, "no-cache, private")
header::CACHE_CONTROL, "no-cache, private").finish(); .finish();
let cache = Header::parse(&req); let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::NoCache, assert_eq!(
CacheDirective::Private]))) cache.ok(),
Some(CacheControl(vec![
CacheDirective::NoCache,
CacheDirective::Private,
]))
)
} }
#[test] #[test]
fn test_parse_argument() { fn test_parse_argument() {
let req = TestRequest::with_header( let req =
header::CACHE_CONTROL, "max-age=100, private").finish(); TestRequest::with_header(header::CACHE_CONTROL, "max-age=100, private")
.finish();
let cache = Header::parse(&req); let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(100), assert_eq!(
CacheDirective::Private]))) cache.ok(),
Some(CacheControl(vec![
CacheDirective::MaxAge(100),
CacheDirective::Private,
]))
)
} }
#[test] #[test]
fn test_parse_quote_form() { fn test_parse_quote_form() {
let req = TestRequest::with_header( let req =
header::CACHE_CONTROL, "max-age=\"200\"").finish(); TestRequest::with_header(header::CACHE_CONTROL, "max-age=\"200\"").finish();
let cache = Header::parse(&req); let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(200)]))) assert_eq!(
cache.ok(),
Some(CacheControl(vec![CacheDirective::MaxAge(200)]))
)
} }
#[test] #[test]
fn test_parse_extension() { fn test_parse_extension() {
let req = TestRequest::with_header( let req =
header::CACHE_CONTROL, "foo, bar=baz").finish(); TestRequest::with_header(header::CACHE_CONTROL, "foo, bar=baz").finish();
let cache = Header::parse(&req); let cache = Header::parse(&req);
assert_eq!(cache.ok(), Some(CacheControl(vec![ assert_eq!(
CacheDirective::Extension("foo".to_owned(), None), cache.ok(),
CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned()))]))) Some(CacheControl(vec![
CacheDirective::Extension("foo".to_owned(), None),
CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())),
]))
)
} }
#[test] #[test]

View File

@ -1,21 +1,21 @@
use header::{QualityItem, CONTENT_LANGUAGE};
use language_tags::LanguageTag; use language_tags::LanguageTag;
use header::{CONTENT_LANGUAGE, QualityItem};
header! { header! {
/// `Content-Language` header, defined in /// `Content-Language` header, defined in
/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2) /// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2)
/// ///
/// The `Content-Language` header field describes the natural language(s) /// The `Content-Language` header field describes the natural language(s)
/// of the intended audience for the representation. Note that this /// of the intended audience for the representation. Note that this
/// might not be equivalent to all the languages used within the /// might not be equivalent to all the languages used within the
/// representation. /// representation.
/// ///
/// # ABNF /// # ABNF
/// ///
/// ```text /// ```text
/// Content-Language = 1#language-tag /// Content-Language = 1#language-tag
/// ``` /// ```
/// ///
/// # Example values /// # Example values
/// ///
/// * `da` /// * `da`
@ -28,7 +28,7 @@ header! {
/// # #[macro_use] extern crate language_tags; /// # #[macro_use] extern crate language_tags;
/// use actix_web::HttpResponse; /// use actix_web::HttpResponse;
/// # use actix_web::http::header::{ContentLanguage, qitem}; /// # use actix_web::http::header::{ContentLanguage, qitem};
/// # /// #
/// # fn main() { /// # fn main() {
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// builder.set( /// builder.set(
@ -46,7 +46,7 @@ header! {
/// # use actix_web::http::header::{ContentLanguage, qitem}; /// # use actix_web::http::header::{ContentLanguage, qitem};
/// # /// #
/// # fn main() { /// # fn main() {
/// ///
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// builder.set( /// builder.set(
/// ContentLanguage(vec![ /// ContentLanguage(vec![

View File

@ -1,8 +1,8 @@
use error::ParseError;
use header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer,
CONTENT_RANGE};
use std::fmt::{self, Display, Write}; use std::fmt::{self, Display, Write};
use std::str::FromStr; use std::str::FromStr;
use error::ParseError;
use header::{IntoHeaderValue, Writer,
HeaderValue, InvalidHeaderValueBytes, CONTENT_RANGE};
header! { header! {
/// `Content-Range` header, defined in /// `Content-Range` header, defined in
@ -69,7 +69,6 @@ header! {
} }
} }
/// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2) /// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2)
/// ///
/// # ABNF /// # ABNF
@ -99,7 +98,7 @@ pub enum ContentRangeSpec {
range: Option<(u64, u64)>, range: Option<(u64, u64)>,
/// Total length of the instance, can be omitted if unknown /// Total length of the instance, can be omitted if unknown
instance_length: Option<u64> instance_length: Option<u64>,
}, },
/// Custom range, with unit not registered at IANA /// Custom range, with unit not registered at IANA
@ -108,15 +107,15 @@ pub enum ContentRangeSpec {
unit: String, unit: String,
/// other-range-resp /// other-range-resp
resp: String resp: String,
} },
} }
fn split_in_two(s: &str, separator: char) -> Option<(&str, &str)> { fn split_in_two(s: &str, separator: char) -> Option<(&str, &str)> {
let mut iter = s.splitn(2, separator); let mut iter = s.splitn(2, separator);
match (iter.next(), iter.next()) { match (iter.next(), iter.next()) {
(Some(a), Some(b)) => Some((a, b)), (Some(a), Some(b)) => Some((a, b)),
_ => None _ => None,
} }
} }
@ -126,40 +125,40 @@ impl FromStr for ContentRangeSpec {
fn from_str(s: &str) -> Result<Self, ParseError> { fn from_str(s: &str) -> Result<Self, ParseError> {
let res = match split_in_two(s, ' ') { let res = match split_in_two(s, ' ') {
Some(("bytes", resp)) => { Some(("bytes", resp)) => {
let (range, instance_length) = split_in_two( let (range, instance_length) =
resp, '/').ok_or(ParseError::Header)?; split_in_two(resp, '/').ok_or(ParseError::Header)?;
let instance_length = if instance_length == "*" { let instance_length = if instance_length == "*" {
None None
} else { } else {
Some(instance_length.parse() Some(instance_length
.map_err(|_| ParseError::Header)?) .parse()
.map_err(|_| ParseError::Header)?)
}; };
let range = if range == "*" { let range = if range == "*" {
None None
} else { } else {
let (first_byte, last_byte) = split_in_two( let (first_byte, last_byte) =
range, '-').ok_or(ParseError::Header)?; split_in_two(range, '-').ok_or(ParseError::Header)?;
let first_byte = first_byte.parse() let first_byte = first_byte.parse().map_err(|_| ParseError::Header)?;
.map_err(|_| ParseError::Header)?; let last_byte = last_byte.parse().map_err(|_| ParseError::Header)?;
let last_byte = last_byte.parse()
.map_err(|_| ParseError::Header)?;
if last_byte < first_byte { if last_byte < first_byte {
return Err(ParseError::Header); return Err(ParseError::Header);
} }
Some((first_byte, last_byte)) Some((first_byte, last_byte))
}; };
ContentRangeSpec::Bytes {range, instance_length} ContentRangeSpec::Bytes {
} range,
Some((unit, resp)) => { instance_length,
ContentRangeSpec::Unregistered {
unit: unit.to_owned(),
resp: resp.to_owned()
} }
} }
_ => return Err(ParseError::Header) Some((unit, resp)) => ContentRangeSpec::Unregistered {
unit: unit.to_owned(),
resp: resp.to_owned(),
},
_ => return Err(ParseError::Header),
}; };
Ok(res) Ok(res)
} }
@ -168,12 +167,15 @@ impl FromStr for ContentRangeSpec {
impl Display for ContentRangeSpec { impl Display for ContentRangeSpec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
ContentRangeSpec::Bytes { range, instance_length } => { ContentRangeSpec::Bytes {
range,
instance_length,
} => {
try!(f.write_str("bytes ")); try!(f.write_str("bytes "));
match range { match range {
Some((first_byte, last_byte)) => { Some((first_byte, last_byte)) => {
try!(write!(f, "{}-{}", first_byte, last_byte)); try!(write!(f, "{}-{}", first_byte, last_byte));
}, }
None => { None => {
try!(f.write_str("*")); try!(f.write_str("*"));
} }
@ -185,7 +187,10 @@ impl Display for ContentRangeSpec {
f.write_str("*") f.write_str("*")
} }
} }
ContentRangeSpec::Unregistered { ref unit, ref resp } => { ContentRangeSpec::Unregistered {
ref unit,
ref resp,
} => {
try!(f.write_str(unit)); try!(f.write_str(unit));
try!(f.write_str(" ")); try!(f.write_str(" "));
f.write_str(resp) f.write_str(resp)

View File

@ -1,6 +1,5 @@
use mime::{self, Mime};
use header::CONTENT_TYPE; use header::CONTENT_TYPE;
use mime::{self, Mime};
header! { header! {
/// `Content-Type` header, defined in /// `Content-Type` header, defined in
@ -68,13 +67,15 @@ header! {
} }
impl ContentType { impl ContentType {
/// A constructor to easily create a `Content-Type: application/json` header. /// A constructor to easily create a `Content-Type: application/json`
/// header.
#[inline] #[inline]
pub fn json() -> ContentType { pub fn json() -> ContentType {
ContentType(mime::APPLICATION_JSON) ContentType(mime::APPLICATION_JSON)
} }
/// A constructor to easily create a `Content-Type: text/plain; charset=utf-8` header. /// A constructor to easily create a `Content-Type: text/plain;
/// charset=utf-8` header.
#[inline] #[inline]
pub fn plaintext() -> ContentType { pub fn plaintext() -> ContentType {
ContentType(mime::TEXT_PLAIN_UTF_8) ContentType(mime::TEXT_PLAIN_UTF_8)
@ -92,7 +93,8 @@ impl ContentType {
ContentType(mime::TEXT_XML) ContentType(mime::TEXT_XML)
} }
/// A constructor to easily create a `Content-Type: application/www-form-url-encoded` header. /// A constructor to easily create a `Content-Type:
/// application/www-form-url-encoded` header.
#[inline] #[inline]
pub fn form_url_encoded() -> ContentType { pub fn form_url_encoded() -> ContentType {
ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) ContentType(mime::APPLICATION_WWW_FORM_URLENCODED)
@ -109,7 +111,8 @@ impl ContentType {
ContentType(mime::IMAGE_PNG) ContentType(mime::IMAGE_PNG)
} }
/// A constructor to easily create a `Content-Type: application/octet-stream` header. /// A constructor to easily create a `Content-Type:
/// application/octet-stream` header.
#[inline] #[inline]
pub fn octet_stream() -> ContentType { pub fn octet_stream() -> ContentType {
ContentType(mime::APPLICATION_OCTET_STREAM) ContentType(mime::APPLICATION_OCTET_STREAM)

View File

@ -1,6 +1,5 @@
use header::{HttpDate, DATE};
use std::time::SystemTime; use std::time::SystemTime;
use header::{DATE, HttpDate};
header! { header! {
/// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2) /// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2)

View File

@ -1,4 +1,4 @@
use header::{ETAG, EntityTag}; use header::{EntityTag, ETAG};
header! { header! {
/// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3) /// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3)

View File

@ -1,4 +1,4 @@
use header::{EXPIRES, HttpDate}; use header::{HttpDate, EXPIRES};
header! { header! {
/// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3) /// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3)

View File

@ -1,4 +1,4 @@
use header::{IF_MATCH, EntityTag}; use header::{EntityTag, IF_MATCH};
header! { header! {
/// `If-Match` header, defined in /// `If-Match` header, defined in

View File

@ -1,4 +1,4 @@
use header::{IF_MODIFIED_SINCE, HttpDate}; use header::{HttpDate, IF_MODIFIED_SINCE};
header! { header! {
/// `If-Modified-Since` header, defined in /// `If-Modified-Since` header, defined in

View File

@ -1,4 +1,4 @@
use header::{IF_NONE_MATCH, EntityTag}; use header::{EntityTag, IF_NONE_MATCH};
header! { header! {
/// `If-None-Match` header, defined in /// `If-None-Match` header, defined in
@ -66,8 +66,8 @@ header! {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::IfNoneMatch; use super::IfNoneMatch;
use header::{EntityTag, Header, IF_NONE_MATCH};
use test::TestRequest; use test::TestRequest;
use header::{IF_NONE_MATCH, Header, EntityTag};
#[test] #[test]
fn test_if_none_match() { fn test_if_none_match() {
@ -77,8 +77,9 @@ mod tests {
if_none_match = Header::parse(&req); if_none_match = Header::parse(&req);
assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any)); assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any));
let req = TestRequest::with_header( let req =
IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]).finish(); TestRequest::with_header(IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..])
.finish();
if_none_match = Header::parse(&req); if_none_match = Header::parse(&req);
let mut entities: Vec<EntityTag> = Vec::new(); let mut entities: Vec<EntityTag> = Vec::new();

View File

@ -1,10 +1,10 @@
use std::fmt::{self, Display, Write};
use error::ParseError; use error::ParseError;
use httpmessage::HttpMessage;
use http::header;
use header::from_one_raw_str; use header::from_one_raw_str;
use header::{IntoHeaderValue, Header, HeaderName, HeaderValue, use header::{EntityTag, Header, HeaderName, HeaderValue, HttpDate, IntoHeaderValue,
EntityTag, HttpDate, Writer, InvalidHeaderValueBytes}; InvalidHeaderValueBytes, Writer};
use http::header;
use httpmessage::HttpMessage;
use std::fmt::{self, Display, Write};
/// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) /// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2)
/// ///
@ -36,16 +36,19 @@ use header::{IntoHeaderValue, Header, HeaderName, HeaderValue,
/// ///
/// ```rust /// ```rust
/// use actix_web::HttpResponse; /// use actix_web::HttpResponse;
/// use actix_web::http::header::{IfRange, EntityTag}; /// use actix_web::http::header::{EntityTag, IfRange};
/// ///
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// builder.set(IfRange::EntityTag(EntityTag::new(false, "xyzzy".to_owned()))); /// builder.set(IfRange::EntityTag(EntityTag::new(
/// false,
/// "xyzzy".to_owned(),
/// )));
/// ``` /// ```
/// ///
/// ```rust /// ```rust
/// use actix_web::HttpResponse; /// use actix_web::HttpResponse;
/// use actix_web::http::header::IfRange; /// use actix_web::http::header::IfRange;
/// use std::time::{SystemTime, Duration}; /// use std::time::{Duration, SystemTime};
/// ///
/// let mut builder = HttpResponse::Ok(); /// let mut builder = HttpResponse::Ok();
/// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24);
@ -64,7 +67,9 @@ impl Header for IfRange {
header::IF_RANGE header::IF_RANGE
} }
#[inline] #[inline]
fn parse<T>(msg: &T) -> Result<Self, ParseError> where T: HttpMessage fn parse<T>(msg: &T) -> Result<Self, ParseError>
where
T: HttpMessage,
{ {
let etag: Result<EntityTag, _> = let etag: Result<EntityTag, _> =
from_one_raw_str(msg.headers().get(header::IF_RANGE)); from_one_raw_str(msg.headers().get(header::IF_RANGE));
@ -99,12 +104,11 @@ impl IntoHeaderValue for IfRange {
} }
} }
#[cfg(test)] #[cfg(test)]
mod test_if_range { mod test_if_range {
use std::str;
use header::*;
use super::IfRange as HeaderField; use super::IfRange as HeaderField;
use header::*;
use std::str;
test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]);
test_header!(test2, vec![b"\"xyzzy\""]); test_header!(test2, vec![b"\"xyzzy\""]);
test_header!(test3, vec![b"this-is-invalid"], None::<IfRange>); test_header!(test3, vec![b"this-is-invalid"], None::<IfRange>);

View File

@ -1,4 +1,4 @@
use header::{IF_UNMODIFIED_SINCE, HttpDate}; use header::{HttpDate, IF_UNMODIFIED_SINCE};
header! { header! {
/// `If-Unmodified-Since` header, defined in /// `If-Unmodified-Since` header, defined in

View File

@ -1,4 +1,4 @@
use header::{LAST_MODIFIED, HttpDate}; use header::{HttpDate, LAST_MODIFIED};
header! { header! {
/// `Last-Modified` header, defined in /// `Last-Modified` header, defined in

View File

@ -5,6 +5,7 @@
//! Several header fields use MIME values for their contents. Keeping with the //! Several header fields use MIME values for their contents. Keeping with the
//! strongly-typed theme, the [mime](https://docs.rs/mime) crate //! strongly-typed theme, the [mime](https://docs.rs/mime) crate
//! is used, such as `ContentType(pub Mime)`. //! is used, such as `ContentType(pub Mime)`.
#![cfg_attr(rustfmt, rustfmt_skip)]
pub use self::accept_charset::AcceptCharset; pub use self::accept_charset::AcceptCharset;
//pub use self::accept_encoding::AcceptEncoding; //pub use self::accept_encoding::AcceptEncoding;

View File

@ -1,8 +1,8 @@
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::str::FromStr; use std::str::FromStr;
use header::parsing::from_one_raw_str;
use header::{Header, Raw}; use header::{Header, Raw};
use header::parsing::{from_one_raw_str};
/// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) /// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1)
/// ///
@ -65,7 +65,7 @@ pub enum Range {
Bytes(Vec<ByteRangeSpec>), Bytes(Vec<ByteRangeSpec>),
/// Custom range, with unit not registered at IANA /// Custom range, with unit not registered at IANA
/// (`other-range-unit`: String , `other-range-set`: String) /// (`other-range-unit`: String , `other-range-set`: String)
Unregistered(String, String) Unregistered(String, String),
} }
/// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`. /// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`.
@ -77,25 +77,25 @@ pub enum ByteRangeSpec {
/// Get all bytes starting from x ("x-") /// Get all bytes starting from x ("x-")
AllFrom(u64), AllFrom(u64),
/// Get last x bytes ("-x") /// Get last x bytes ("-x")
Last(u64) Last(u64),
} }
impl ByteRangeSpec { impl ByteRangeSpec {
/// Given the full length of the entity, attempt to normalize the byte range /// Given the full length of the entity, attempt to normalize the byte range
/// into an satisfiable end-inclusive (from, to) range. /// into an satisfiable end-inclusive (from, to) range.
/// ///
/// The resulting range is guaranteed to be a satisfiable range within the bounds /// The resulting range is guaranteed to be a satisfiable range within the
/// of `0 <= from <= to < full_length`. /// bounds of `0 <= from <= to < full_length`.
/// ///
/// If the byte range is deemed unsatisfiable, `None` is returned. /// If the byte range is deemed unsatisfiable, `None` is returned.
/// An unsatisfiable range is generally cause for a server to either reject /// An unsatisfiable range is generally cause for a server to either reject
/// the client request with a `416 Range Not Satisfiable` status code, or to /// the client request with a `416 Range Not Satisfiable` status code, or to
/// simply ignore the range header and serve the full entity using a `200 OK` /// simply ignore the range header and serve the full entity using a `200
/// status code. /// OK` status code.
/// ///
/// This function closely follows [RFC 7233][1] section 2.1. /// This function closely follows [RFC 7233][1] section 2.1.
/// As such, it considers ranges to be satisfiable if they meet the following /// As such, it considers ranges to be satisfiable if they meet the
/// conditions: /// following conditions:
/// ///
/// > If a valid byte-range-set includes at least one byte-range-spec with /// > If a valid byte-range-set includes at least one byte-range-spec with
/// a first-byte-pos that is less than the current length of the /// a first-byte-pos that is less than the current length of the
@ -125,14 +125,14 @@ impl ByteRangeSpec {
} else { } else {
None None
} }
}, }
&ByteRangeSpec::AllFrom(from) => { &ByteRangeSpec::AllFrom(from) => {
if from < full_length { if from < full_length {
Some((from, full_length - 1)) Some((from, full_length - 1))
} else { } else {
None None
} }
}, }
&ByteRangeSpec::Last(last) => { &ByteRangeSpec::Last(last) => {
if last > 0 { if last > 0 {
// From the RFC: If the selected representation is shorter // From the RFC: If the selected representation is shorter
@ -160,11 +160,15 @@ impl Range {
/// Get byte range header with multiple subranges /// Get byte range header with multiple subranges
/// ("bytes=from1-to1,from2-to2,fromX-toX") /// ("bytes=from1-to1,from2-to2,fromX-toX")
pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range {
Range::Bytes(ranges.iter().map(|r| ByteRangeSpec::FromTo(r.0, r.1)).collect()) Range::Bytes(
ranges
.iter()
.map(|r| ByteRangeSpec::FromTo(r.0, r.1))
.collect(),
)
} }
} }
impl fmt::Display for ByteRangeSpec { impl fmt::Display for ByteRangeSpec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
@ -175,7 +179,6 @@ impl fmt::Display for ByteRangeSpec {
} }
} }
impl fmt::Display for Range { impl fmt::Display for Range {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
@ -189,10 +192,10 @@ impl fmt::Display for Range {
try!(Display::fmt(range, f)); try!(Display::fmt(range, f));
} }
Ok(()) Ok(())
}, }
Range::Unregistered(ref unit, ref range_str) => { Range::Unregistered(ref unit, ref range_str) => {
write!(f, "{}={}", unit, range_str) write!(f, "{}={}", unit, range_str)
}, }
} }
} }
} }
@ -211,11 +214,10 @@ impl FromStr for Range {
} }
Ok(Range::Bytes(ranges)) Ok(Range::Bytes(ranges))
} }
(Some(unit), Some(range_str)) if unit != "" && range_str != "" => { (Some(unit), Some(range_str)) if unit != "" && range_str != "" => Ok(
Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())) Range::Unregistered(unit.to_owned(), range_str.to_owned()),
),
}, _ => Err(::Error::Header),
_ => Err(::Error::Header)
} }
} }
} }
@ -227,19 +229,20 @@ impl FromStr for ByteRangeSpec {
let mut parts = s.splitn(2, '-'); let mut parts = s.splitn(2, '-');
match (parts.next(), parts.next()) { match (parts.next(), parts.next()) {
(Some(""), Some(end)) => { (Some(""), Some(end)) => end.parse()
end.parse().or(Err(::Error::Header)).map(ByteRangeSpec::Last) .or(Err(::Error::Header))
}, .map(ByteRangeSpec::Last),
(Some(start), Some("")) => { (Some(start), Some("")) => start
start.parse().or(Err(::Error::Header)).map(ByteRangeSpec::AllFrom) .parse()
}, .or(Err(::Error::Header))
(Some(start), Some(end)) => { .map(ByteRangeSpec::AllFrom),
match (start.parse(), end.parse()) { (Some(start), Some(end)) => match (start.parse(), end.parse()) {
(Ok(start), Ok(end)) if start <= end => Ok(ByteRangeSpec::FromTo(start, end)), (Ok(start), Ok(end)) if start <= end => {
_ => Err(::Error::Header) Ok(ByteRangeSpec::FromTo(start, end))
} }
_ => Err(::Error::Header),
}, },
_ => Err(::Error::Header) _ => Err(::Error::Header),
} }
} }
} }
@ -248,14 +251,13 @@ fn from_comma_delimited<T: FromStr>(s: &str) -> Vec<T> {
s.split(',') s.split(',')
.filter_map(|x| match x.trim() { .filter_map(|x| match x.trim() {
"" => None, "" => None,
y => Some(y) y => Some(y),
}) })
.filter_map(|x| x.parse().ok()) .filter_map(|x| x.parse().ok())
.collect() .collect()
} }
impl Header for Range { impl Header for Range {
fn header_name() -> &'static str { fn header_name() -> &'static str {
static NAME: &'static str = "Range"; static NAME: &'static str = "Range";
NAME NAME
@ -268,51 +270,52 @@ impl Header for Range {
fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result {
f.fmt_line(self) f.fmt_line(self)
} }
} }
#[test] #[test]
fn test_parse_bytes_range_valid() { fn test_parse_bytes_range_valid() {
let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap(); let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap(); let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap();
let r3 = Range::bytes(1, 100); let r3 = Range::bytes(1, 100);
assert_eq!(r, r2); assert_eq!(r, r2);
assert_eq!(r2, r3); assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap(); let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); let r2: Range =
let r3 = Range::Bytes( Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap();
vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)] let r3 = Range::Bytes(vec![
); ByteRangeSpec::FromTo(1, 100),
ByteRangeSpec::AllFrom(200),
]);
assert_eq!(r, r2); assert_eq!(r, r2);
assert_eq!(r2, r3); assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap(); let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap();
let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap(); let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap();
let r3 = Range::Bytes( let r3 = Range::Bytes(vec![
vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::Last(100)] ByteRangeSpec::FromTo(1, 100),
); ByteRangeSpec::Last(100),
]);
assert_eq!(r, r2); assert_eq!(r, r2);
assert_eq!(r2, r3); assert_eq!(r2, r3);
let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
assert_eq!(r, r2); assert_eq!(r, r2);
} }
#[test] #[test]
fn test_parse_unregistered_range_valid() { fn test_parse_unregistered_range_valid() {
let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned());
assert_eq!(r, r2); assert_eq!(r, r2);
let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap(); let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned());
assert_eq!(r, r2); assert_eq!(r, r2);
let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap(); let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap();
let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned());
assert_eq!(r, r2); assert_eq!(r, r2);
} }
@ -346,10 +349,10 @@ fn test_fmt() {
let mut headers = Headers::new(); let mut headers = Headers::new();
headers.set( headers.set(Range::Bytes(vec![
Range::Bytes( ByteRangeSpec::FromTo(0, 1000),
vec![ByteRangeSpec::FromTo(0, 1000), ByteRangeSpec::AllFrom(2000)] ByteRangeSpec::AllFrom(2000),
)); ]));
assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n"); assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n");
headers.clear(); headers.clear();
@ -358,30 +361,74 @@ fn test_fmt() {
assert_eq!(&headers.to_string(), "Range: bytes=\r\n"); assert_eq!(&headers.to_string(), "Range: bytes=\r\n");
headers.clear(); headers.clear();
headers.set(Range::Unregistered("custom".to_owned(), "1-xxx".to_owned())); headers.set(Range::Unregistered(
"custom".to_owned(),
"1-xxx".to_owned(),
));
assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n"); assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n");
} }
#[test] #[test]
fn test_byte_range_spec_to_satisfiable_range() { fn test_byte_range_spec_to_satisfiable_range() {
assert_eq!(Some((0, 0)), ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3)); assert_eq!(
assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3)); Some((0, 0)),
assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3)); ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3)
assert_eq!(None, ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)); );
assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)); assert_eq!(
assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)); Some((1, 2)),
ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3)
);
assert_eq!(
Some((1, 2)),
ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)
);
assert_eq!(Some((0, 2)), ByteRangeSpec::AllFrom(0).to_satisfiable_range(3)); assert_eq!(
assert_eq!(Some((2, 2)), ByteRangeSpec::AllFrom(2).to_satisfiable_range(3)); Some((0, 2)),
assert_eq!(None, ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)); ByteRangeSpec::AllFrom(0).to_satisfiable_range(3)
assert_eq!(None, ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)); );
assert_eq!(None, ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)); assert_eq!(
Some((2, 2)),
ByteRangeSpec::AllFrom(2).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)
);
assert_eq!(
None,
ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)
);
assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3)); assert_eq!(
assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3)); Some((1, 2)),
assert_eq!(Some((0, 2)), ByteRangeSpec::Last(5).to_satisfiable_range(3)); ByteRangeSpec::Last(2).to_satisfiable_range(3)
);
assert_eq!(
Some((2, 2)),
ByteRangeSpec::Last(1).to_satisfiable_range(3)
);
assert_eq!(
Some((0, 2)),
ByteRangeSpec::Last(5).to_satisfiable_range(3)
);
assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3)); assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3));
assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0)); assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0));
} }

View File

@ -5,9 +5,9 @@ use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use modhttp::{Error as HttpError};
use modhttp::header::GetAll;
use mime::Mime; use mime::Mime;
use modhttp::Error as HttpError;
use modhttp::header::GetAll;
pub use modhttp::header::*; pub use modhttp::header::*;
@ -21,11 +21,12 @@ pub use self::common::*;
#[doc(hidden)] #[doc(hidden)]
pub use self::shared::*; pub use self::shared::*;
#[doc(hidden)] #[doc(hidden)]
/// A trait for any object that will represent a header field and value. /// A trait for any object that will represent a header field and value.
pub trait Header where Self: IntoHeaderValue { pub trait Header
where
Self: IntoHeaderValue,
{
/// Returns the name of the header field /// Returns the name of the header field
fn name() -> HeaderName; fn name() -> HeaderName;
@ -112,7 +113,7 @@ pub enum ContentEncoding {
/// Automatically select encoding based on encoding negotiation /// Automatically select encoding based on encoding negotiation
Auto, Auto,
/// A format using the Brotli algorithm /// A format using the Brotli algorithm
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
Br, Br,
/// A format using the zlib structure with deflate algorithm /// A format using the zlib structure with deflate algorithm
Deflate, Deflate,
@ -123,19 +124,18 @@ pub enum ContentEncoding {
} }
impl ContentEncoding { impl ContentEncoding {
#[inline] #[inline]
pub fn is_compression(&self) -> bool { pub fn is_compression(&self) -> bool {
match *self { match *self {
ContentEncoding::Identity | ContentEncoding::Auto => false, ContentEncoding::Identity | ContentEncoding::Auto => false,
_ => true _ => true,
} }
} }
#[inline] #[inline]
pub fn as_str(&self) -> &'static str { pub fn as_str(&self) -> &'static str {
match *self { match *self {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoding::Br => "br", ContentEncoding::Br => "br",
ContentEncoding::Gzip => "gzip", ContentEncoding::Gzip => "gzip",
ContentEncoding::Deflate => "deflate", ContentEncoding::Deflate => "deflate",
@ -147,7 +147,7 @@ impl ContentEncoding {
/// default quality value /// default quality value
pub fn quality(&self) -> f64 { pub fn quality(&self) -> f64 {
match *self { match *self {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoding::Br => 1.1, ContentEncoding::Br => 1.1,
ContentEncoding::Gzip => 1.0, ContentEncoding::Gzip => 1.0,
ContentEncoding::Deflate => 0.9, ContentEncoding::Deflate => 0.9,
@ -160,7 +160,7 @@ impl ContentEncoding {
impl<'a> From<&'a str> for ContentEncoding { impl<'a> From<&'a str> for ContentEncoding {
fn from(s: &'a str) -> ContentEncoding { fn from(s: &'a str) -> ContentEncoding {
match s.trim().to_lowercase().as_ref() { match s.trim().to_lowercase().as_ref() {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
"br" => ContentEncoding::Br, "br" => ContentEncoding::Br,
"gzip" => ContentEncoding::Gzip, "gzip" => ContentEncoding::Gzip,
"deflate" => ContentEncoding::Deflate, "deflate" => ContentEncoding::Deflate,
@ -176,7 +176,9 @@ pub(crate) struct Writer {
impl Writer { impl Writer {
fn new() -> Writer { fn new() -> Writer {
Writer{buf: BytesMut::new()} Writer {
buf: BytesMut::new(),
}
} }
fn take(&mut self) -> Bytes { fn take(&mut self) -> Bytes {
self.buf.take().freeze() self.buf.take().freeze()
@ -199,18 +201,20 @@ impl fmt::Write for Writer {
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
/// Reads a comma-delimited raw header into a Vec. /// Reads a comma-delimited raw header into a Vec.
pub fn from_comma_delimited<T: FromStr>(all: GetAll<HeaderValue>) pub fn from_comma_delimited<T: FromStr>(
-> Result<Vec<T>, ParseError> all: GetAll<HeaderValue>
{ ) -> Result<Vec<T>, ParseError> {
let mut result = Vec::new(); let mut result = Vec::new();
for h in all { for h in all {
let s = h.to_str().map_err(|_| ParseError::Header)?; let s = h.to_str().map_err(|_| ParseError::Header)?;
result.extend(s.split(',') result.extend(
.filter_map(|x| match x.trim() { s.split(',')
"" => None, .filter_map(|x| match x.trim() {
y => Some(y) "" => None,
}) y => Some(y),
.filter_map(|x| x.trim().parse().ok())) })
.filter_map(|x| x.trim().parse().ok()),
)
} }
Ok(result) Ok(result)
} }
@ -218,13 +222,11 @@ pub fn from_comma_delimited<T: FromStr>(all: GetAll<HeaderValue>)
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
/// Reads a single string when parsing a header. /// Reads a single string when parsing a header.
pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>) pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>) -> Result<T, ParseError> {
-> Result<T, ParseError>
{
if let Some(line) = val { if let Some(line) = val {
let line = line.to_str().map_err(|_| ParseError::Header)?; let line = line.to_str().map_err(|_| ParseError::Header)?;
if !line.is_empty() { if !line.is_empty() {
return T::from_str(line).or(Err(ParseError::Header)) return T::from_str(line).or(Err(ParseError::Header));
} }
} }
Err(ParseError::Header) Err(ParseError::Header)
@ -234,7 +236,8 @@ pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>)
#[doc(hidden)] #[doc(hidden)]
/// Format an array into a comma-delimited string. /// Format an array into a comma-delimited string.
pub fn fmt_comma_delimited<T>(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result pub fn fmt_comma_delimited<T>(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result
where T: fmt::Display where
T: fmt::Display,
{ {
let mut iter = parts.iter(); let mut iter = parts.iter();
if let Some(part) = iter.next() { if let Some(part) = iter.next() {

View File

@ -1,7 +1,7 @@
#![allow(unused, deprecated)] #![allow(unused, deprecated)]
use std::ascii::AsciiExt;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::str::FromStr; use std::str::FromStr;
use std::ascii::AsciiExt;
use self::Charset::*; use self::Charset::*;
@ -12,9 +12,9 @@ use self::Charset::*;
/// See [http://www.iana.org/assignments/character-sets/character-sets.xhtml][url]. /// See [http://www.iana.org/assignments/character-sets/character-sets.xhtml][url].
/// ///
/// [url]: http://www.iana.org/assignments/character-sets/character-sets.xhtml /// [url]: http://www.iana.org/assignments/character-sets/character-sets.xhtml
#[derive(Clone,Debug,PartialEq)] #[derive(Clone, Debug, PartialEq)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub enum Charset{ pub enum Charset {
/// US ASCII /// US ASCII
Us_Ascii, Us_Ascii,
/// ISO-8859-1 /// ISO-8859-1
@ -64,7 +64,7 @@ pub enum Charset{
/// KOI8-R /// KOI8-R
Koi8_R, Koi8_R,
/// An arbitrary charset specified as a string /// An arbitrary charset specified as a string
Ext(String) Ext(String),
} }
impl Charset { impl Charset {
@ -94,7 +94,7 @@ impl Charset {
Gb2312 => "GB2312", Gb2312 => "GB2312",
Big5 => "5", Big5 => "5",
Koi8_R => "KOI8-R", Koi8_R => "KOI8-R",
Ext(ref s) => s Ext(ref s) => s,
} }
} }
} }
@ -133,18 +133,18 @@ impl FromStr for Charset {
"GB2312" => Gb2312, "GB2312" => Gb2312,
"5" => Big5, "5" => Big5,
"KOI8-R" => Koi8_R, "KOI8-R" => Koi8_R,
s => Ext(s.to_owned()) s => Ext(s.to_owned()),
}) })
} }
} }
#[test] #[test]
fn test_parse() { fn test_parse() {
assert_eq!(Us_Ascii,"us-ascii".parse().unwrap()); assert_eq!(Us_Ascii, "us-ascii".parse().unwrap());
assert_eq!(Us_Ascii,"US-Ascii".parse().unwrap()); assert_eq!(Us_Ascii, "US-Ascii".parse().unwrap());
assert_eq!(Us_Ascii,"US-ASCII".parse().unwrap()); assert_eq!(Us_Ascii, "US-ASCII".parse().unwrap());
assert_eq!(Shift_Jis,"Shift-JIS".parse().unwrap()); assert_eq!(Shift_Jis, "Shift-JIS".parse().unwrap());
assert_eq!(Ext("ABCD".to_owned()),"abcd".parse().unwrap()); assert_eq!(Ext("ABCD".to_owned()), "abcd".parse().unwrap());
} }
#[test] #[test]

View File

@ -1,7 +1,8 @@
use std::fmt; use std::fmt;
use std::str; use std::str;
pub use self::Encoding::{Chunked, Brotli, Gzip, Deflate, Compress, Identity, EncodingExt, Trailers}; pub use self::Encoding::{Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip,
Identity, Trailers};
/// A value to represent an encoding used in `Transfer-Encoding` /// A value to represent an encoding used in `Transfer-Encoding`
/// or `Accept-Encoding` header. /// or `Accept-Encoding` header.
@ -22,7 +23,7 @@ pub enum Encoding {
/// The `trailers` encoding. /// The `trailers` encoding.
Trailers, Trailers,
/// Some other encoding that is less common, can be any String. /// Some other encoding that is less common, can be any String.
EncodingExt(String) EncodingExt(String),
} }
impl fmt::Display for Encoding { impl fmt::Display for Encoding {
@ -35,7 +36,7 @@ impl fmt::Display for Encoding {
Compress => "compress", Compress => "compress",
Identity => "identity", Identity => "identity",
Trailers => "trailers", Trailers => "trailers",
EncodingExt(ref s) => s.as_ref() EncodingExt(ref s) => s.as_ref(),
}) })
} }
} }
@ -51,7 +52,7 @@ impl str::FromStr for Encoding {
"compress" => Ok(Compress), "compress" => Ok(Compress),
"identity" => Ok(Identity), "identity" => Ok(Identity),
"trailers" => Ok(Trailers), "trailers" => Ok(Trailers),
_ => Ok(EncodingExt(s.to_owned())) _ => Ok(EncodingExt(s.to_owned())),
} }
} }
} }

View File

@ -1,21 +1,23 @@
use std::str::FromStr; use header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer};
use std::fmt::{self, Display, Write}; use std::fmt::{self, Display, Write};
use header::{HeaderValue, Writer, IntoHeaderValue, InvalidHeaderValueBytes}; use std::str::FromStr;
/// check that each char in the slice is either: /// check that each char in the slice is either:
/// 1. `%x21`, or /// 1. `%x21`, or
/// 2. in the range `%x23` to `%x7E`, or /// 2. in the range `%x23` to `%x7E`, or
/// 3. above `%x80` /// 3. above `%x80`
fn check_slice_validity(slice: &str) -> bool { fn check_slice_validity(slice: &str) -> bool {
slice.bytes().all(|c| slice
c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80')) .bytes()
.all(|c| c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80'))
} }
/// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) /// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3)
/// ///
/// An entity tag consists of a string enclosed by two literal double quotes. /// An entity tag consists of a string enclosed by two literal double quotes.
/// Preceding the first double quote is an optional weakness indicator, /// Preceding the first double quote is an optional weakness indicator,
/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and `W/"xyzzy"`. /// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and
/// `W/"xyzzy"`.
/// ///
/// # ABNF /// # ABNF
/// ///
@ -28,9 +30,9 @@ fn check_slice_validity(slice: &str) -> bool {
/// ``` /// ```
/// ///
/// # Comparison /// # Comparison
/// To check if two entity tags are equivalent in an application always use the `strong_eq` or /// To check if two entity tags are equivalent in an application always use the
/// `weak_eq` methods based on the context of the Tag. Only use `==` to check if two tags are /// `strong_eq` or `weak_eq` methods based on the context of the Tag. Only use
/// identical. /// `==` to check if two tags are identical.
/// ///
/// The example below shows the results for a set of entity-tag pairs and /// The example below shows the results for a set of entity-tag pairs and
/// both the weak and strong comparison function results: /// both the weak and strong comparison function results:
@ -46,7 +48,7 @@ pub struct EntityTag {
/// Weakness indicator for the tag /// Weakness indicator for the tag
pub weak: bool, pub weak: bool,
/// The opaque string in between the DQUOTEs /// The opaque string in between the DQUOTEs
tag: String tag: String,
} }
impl EntityTag { impl EntityTag {
@ -85,8 +87,8 @@ impl EntityTag {
self.tag = tag self.tag = tag
} }
/// For strong comparison two entity-tags are equivalent if both are not weak and their /// For strong comparison two entity-tags are equivalent if both are not
/// opaque-tags match character-by-character. /// weak and their opaque-tags match character-by-character.
pub fn strong_eq(&self, other: &EntityTag) -> bool { pub fn strong_eq(&self, other: &EntityTag) -> bool {
!self.weak && !other.weak && self.tag == other.tag !self.weak && !other.weak && self.tag == other.tag
} }
@ -131,13 +133,21 @@ impl FromStr for EntityTag {
} }
// The etag is weak if its first char is not a DQUOTE. // The etag is weak if its first char is not a DQUOTE.
if slice.len() >= 2 && slice.starts_with('"') if slice.len() >= 2 && slice.starts_with('"')
&& check_slice_validity(&slice[1..length-1]) { && check_slice_validity(&slice[1..length - 1])
{
// No need to check if the last char is a DQUOTE, // No need to check if the last char is a DQUOTE,
// we already did that above. // we already did that above.
return Ok(EntityTag { weak: false, tag: slice[1..length-1].to_owned() }); return Ok(EntityTag {
weak: false,
tag: slice[1..length - 1].to_owned(),
});
} else if slice.len() >= 4 && slice.starts_with("W/\"") } else if slice.len() >= 4 && slice.starts_with("W/\"")
&& check_slice_validity(&slice[3..length-1]) { && check_slice_validity(&slice[3..length - 1])
return Ok(EntityTag { weak: true, tag: slice[3..length-1].to_owned() }); {
return Ok(EntityTag {
weak: true,
tag: slice[3..length - 1].to_owned(),
});
} }
Err(::error::ParseError::Header) Err(::error::ParseError::Header)
} }
@ -149,7 +159,7 @@ impl IntoHeaderValue for EntityTag {
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = Writer::new(); let mut wrt = Writer::new();
write!(wrt, "{}", self).unwrap(); write!(wrt, "{}", self).unwrap();
unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.take()))} unsafe { Ok(HeaderValue::from_shared_unchecked(wrt.take())) }
} }
} }
@ -160,22 +170,37 @@ mod tests {
#[test] #[test]
fn test_etag_parse_success() { fn test_etag_parse_success() {
// Expected success // Expected success
assert_eq!("\"foobar\"".parse::<EntityTag>().unwrap(), assert_eq!(
EntityTag::strong("foobar".to_owned())); "\"foobar\"".parse::<EntityTag>().unwrap(),
assert_eq!("\"\"".parse::<EntityTag>().unwrap(), EntityTag::strong("foobar".to_owned())
EntityTag::strong("".to_owned())); );
assert_eq!("W/\"weaktag\"".parse::<EntityTag>().unwrap(), assert_eq!(
EntityTag::weak("weaktag".to_owned())); "\"\"".parse::<EntityTag>().unwrap(),
assert_eq!("W/\"\x65\x62\"".parse::<EntityTag>().unwrap(), EntityTag::strong("".to_owned())
EntityTag::weak("\x65\x62".to_owned())); );
assert_eq!("W/\"\"".parse::<EntityTag>().unwrap(), EntityTag::weak("".to_owned())); assert_eq!(
"W/\"weaktag\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("weaktag".to_owned())
);
assert_eq!(
"W/\"\x65\x62\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("\x65\x62".to_owned())
);
assert_eq!(
"W/\"\"".parse::<EntityTag>().unwrap(),
EntityTag::weak("".to_owned())
);
} }
#[test] #[test]
fn test_etag_parse_failures() { fn test_etag_parse_failures() {
// Expected failures // Expected failures
assert!("no-dquotes".parse::<EntityTag>().is_err()); assert!("no-dquotes".parse::<EntityTag>().is_err());
assert!("w/\"the-first-w-is-case-sensitive\"".parse::<EntityTag>().is_err()); assert!(
"w/\"the-first-w-is-case-sensitive\""
.parse::<EntityTag>()
.is_err()
);
assert!("".parse::<EntityTag>().is_err()); assert!("".parse::<EntityTag>().is_err());
assert!("\"unmatched-dquotes1".parse::<EntityTag>().is_err()); assert!("\"unmatched-dquotes1".parse::<EntityTag>().is_err());
assert!("unmatched-dquotes2\"".parse::<EntityTag>().is_err()); assert!("unmatched-dquotes2\"".parse::<EntityTag>().is_err());
@ -184,11 +209,26 @@ mod tests {
#[test] #[test]
fn test_etag_fmt() { fn test_etag_fmt() {
assert_eq!(format!("{}", EntityTag::strong("foobar".to_owned())), "\"foobar\""); assert_eq!(
assert_eq!(format!("{}", EntityTag::strong("".to_owned())), "\"\""); format!("{}", EntityTag::strong("foobar".to_owned())),
assert_eq!(format!("{}", EntityTag::weak("weak-etag".to_owned())), "W/\"weak-etag\""); "\"foobar\""
assert_eq!(format!("{}", EntityTag::weak("\u{0065}".to_owned())), "W/\"\x65\""); );
assert_eq!(format!("{}", EntityTag::weak("".to_owned())), "W/\"\""); assert_eq!(
format!("{}", EntityTag::strong("".to_owned())),
"\"\""
);
assert_eq!(
format!("{}", EntityTag::weak("weak-etag".to_owned())),
"W/\"weak-etag\""
);
assert_eq!(
format!("{}", EntityTag::weak("\u{0065}".to_owned())),
"W/\"\x65\""
);
assert_eq!(
format!("{}", EntityTag::weak("".to_owned())),
"W/\"\""
);
} }
#[test] #[test]

View File

@ -3,14 +3,13 @@ use std::io::Write;
use std::str::FromStr; use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use time; use bytes::{BufMut, BytesMut};
use bytes::{BytesMut, BufMut};
use http::header::{HeaderValue, InvalidHeaderValueBytes}; use http::header::{HeaderValue, InvalidHeaderValueBytes};
use time;
use error::ParseError; use error::ParseError;
use header::IntoHeaderValue; use header::IntoHeaderValue;
/// A timestamp with HTTP formatting and parsing /// A timestamp with HTTP formatting and parsing
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HttpDate(time::Tm); pub struct HttpDate(time::Tm);
@ -19,11 +18,10 @@ impl FromStr for HttpDate {
type Err = ParseError; type Err = ParseError;
fn from_str(s: &str) -> Result<HttpDate, ParseError> { fn from_str(s: &str) -> Result<HttpDate, ParseError> {
match time::strptime(s, "%a, %d %b %Y %T %Z").or_else(|_| { match time::strptime(s, "%a, %d %b %Y %T %Z")
time::strptime(s, "%A, %d-%b-%y %T %Z") .or_else(|_| time::strptime(s, "%A, %d-%b-%y %T %Z"))
}).or_else(|_| { .or_else(|_| time::strptime(s, "%c"))
time::strptime(s, "%c") {
}) {
Ok(t) => Ok(HttpDate(t)), Ok(t) => Ok(HttpDate(t)),
Err(_) => Err(ParseError::Header), Err(_) => Err(ParseError::Header),
} }
@ -47,11 +45,14 @@ impl From<SystemTime> for HttpDate {
let tmspec = match sys.duration_since(UNIX_EPOCH) { let tmspec = match sys.duration_since(UNIX_EPOCH) {
Ok(dur) => { Ok(dur) => {
time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32) time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32)
}, }
Err(err) => { Err(err) => {
let neg = err.duration(); let neg = err.duration();
time::Timespec::new(-(neg.as_secs() as i64), -(neg.subsec_nanos() as i32)) time::Timespec::new(
}, -(neg.as_secs() as i64),
-(neg.subsec_nanos() as i32),
)
}
}; };
HttpDate(time::at_utc(tmspec)) HttpDate(time::at_utc(tmspec))
} }
@ -63,7 +64,11 @@ impl IntoHeaderValue for HttpDate {
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer(); let mut wrt = BytesMut::with_capacity(29).writer();
write!(wrt, "{}", self.0.rfc822()).unwrap(); write!(wrt, "{}", self.0.rfc822()).unwrap();
unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.get_mut().take().freeze()))} unsafe {
Ok(HeaderValue::from_shared_unchecked(
wrt.get_mut().take().freeze(),
))
}
} }
} }
@ -80,18 +85,43 @@ impl From<HttpDate> for SystemTime {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use time::Tm;
use super::HttpDate; use super::HttpDate;
use time::Tm;
const NOV_07: HttpDate = HttpDate(Tm { const NOV_07: HttpDate = HttpDate(Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8, tm_mday: 7, tm_mon: 10, tm_year: 94, tm_nsec: 0,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0}); tm_sec: 37,
tm_min: 48,
tm_hour: 8,
tm_mday: 7,
tm_mon: 10,
tm_year: 94,
tm_wday: 0,
tm_isdst: 0,
tm_yday: 0,
tm_utcoff: 0,
});
#[test] #[test]
fn test_date() { fn test_date() {
assert_eq!("Sun, 07 Nov 1994 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07); assert_eq!(
assert_eq!("Sunday, 07-Nov-94 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07); "Sun, 07 Nov 1994 08:48:37 GMT"
assert_eq!("Sun Nov 7 08:48:37 1994".parse::<HttpDate>().unwrap(), NOV_07); .parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert_eq!(
"Sunday, 07-Nov-94 08:48:37 GMT"
.parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert_eq!(
"Sun Nov 7 08:48:37 1994"
.parse::<HttpDate>()
.unwrap(),
NOV_07
);
assert!("this-is-no-date".parse::<HttpDate>().is_err()); assert!("this-is-no-date".parse::<HttpDate>().is_err());
} }
} }

View File

@ -4,11 +4,11 @@ pub use self::charset::Charset;
pub use self::encoding::Encoding; pub use self::encoding::Encoding;
pub use self::entity::EntityTag; pub use self::entity::EntityTag;
pub use self::httpdate::HttpDate; pub use self::httpdate::HttpDate;
pub use self::quality_item::{q, qitem, Quality, QualityItem};
pub use language_tags::LanguageTag; pub use language_tags::LanguageTag;
pub use self::quality_item::{Quality, QualityItem, qitem, q};
mod charset; mod charset;
mod entity;
mod encoding; mod encoding;
mod entity;
mod httpdate; mod httpdate;
mod quality_item; mod quality_item;

View File

@ -13,11 +13,13 @@ use self::internal::IntoQuality;
/// ///
/// # Implementation notes /// # Implementation notes
/// ///
/// The quality value is defined as a number between 0 and 1 with three decimal places. This means /// The quality value is defined as a number between 0 and 1 with three decimal
/// there are 1001 possible values. Since floating point numbers are not exact and the smallest /// places. This means there are 1001 possible values. Since floating point
/// floating point data type (`f32`) consumes four bytes, hyper uses an `u16` value to store the /// numbers are not exact and the smallest floating point data type (`f32`)
/// quality internally. For performance reasons you may set quality directly to a value between /// consumes four bytes, hyper uses an `u16` value to store the
/// 0 and 1000 e.g. `Quality(532)` matches the quality `q=0.532`. /// quality internally. For performance reasons you may set quality directly to
/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality
/// `q=0.532`.
/// ///
/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1) /// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1)
/// gives more information on quality values in HTTP header fields. /// gives more information on quality values in HTTP header fields.
@ -61,7 +63,11 @@ impl<T: fmt::Display> fmt::Display for QualityItem<T> {
match self.quality.0 { match self.quality.0 {
1000 => Ok(()), 1000 => Ok(()),
0 => f.write_str("; q=0"), 0 => f.write_str("; q=0"),
x => write!(f, "; q=0.{}", format!("{:03}", x).trim_right_matches('0')) x => write!(
f,
"; q=0.{}",
format!("{:03}", x).trim_right_matches('0')
),
} }
} }
} }
@ -96,7 +102,7 @@ impl<T: str::FromStr> str::FromStr for QualityItem<T> {
} else { } else {
return Err(::error::ParseError::Header); return Err(::error::ParseError::Header);
} }
}, }
Err(_) => return Err(::error::ParseError::Header), Err(_) => return Err(::error::ParseError::Header),
} }
} }
@ -114,7 +120,10 @@ fn from_f32(f: f32) -> Quality {
// this function is only used internally. A check that `f` is within range // this function is only used internally. A check that `f` is within range
// should be done before calling this method. Just in case, this // should be done before calling this method. Just in case, this
// debug_assert should catch if we were forgetful // debug_assert should catch if we were forgetful
debug_assert!(f >= 0f32 && f <= 1f32, "q value must be between 0.0 and 1.0"); debug_assert!(
f >= 0f32 && f <= 1f32,
"q value must be between 0.0 and 1.0"
);
Quality((f * 1000f32) as u16) Quality((f * 1000f32) as u16)
} }
@ -125,7 +134,7 @@ pub fn qitem<T>(item: T) -> QualityItem<T> {
} }
/// Convenience function to create a `Quality` from a float or integer. /// Convenience function to create a `Quality` from a float or integer.
/// ///
/// Implemented for `u16` and `f32`. Panics if value is out of range. /// Implemented for `u16` and `f32`. Panics if value is out of range.
pub fn q<T: IntoQuality>(val: T) -> Quality { pub fn q<T: IntoQuality>(val: T) -> Quality {
val.into_quality() val.into_quality()
@ -147,7 +156,10 @@ mod internal {
impl IntoQuality for f32 { impl IntoQuality for f32 {
fn into_quality(self) -> Quality { fn into_quality(self) -> Quality {
assert!(self >= 0f32 && self <= 1f32, "float must be between 0.0 and 1.0"); assert!(
self >= 0f32 && self <= 1f32,
"float must be between 0.0 and 1.0"
);
super::from_f32(self) super::from_f32(self)
} }
} }
@ -159,7 +171,6 @@ mod internal {
} }
} }
pub trait Sealed {} pub trait Sealed {}
impl Sealed for u16 {} impl Sealed for u16 {}
impl Sealed for f32 {} impl Sealed for f32 {}
@ -167,8 +178,8 @@ mod internal {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use super::super::encoding::*; use super::super::encoding::*;
use super::*;
#[test] #[test]
fn test_quality_item_fmt_q_1() { fn test_quality_item_fmt_q_1() {
@ -183,7 +194,7 @@ mod tests {
#[test] #[test]
fn test_quality_item_fmt_q_05() { fn test_quality_item_fmt_q_05() {
// Custom value // Custom value
let x = QualityItem{ let x = QualityItem {
item: EncodingExt("identity".to_owned()), item: EncodingExt("identity".to_owned()),
quality: Quality(500), quality: Quality(500),
}; };
@ -193,7 +204,7 @@ mod tests {
#[test] #[test]
fn test_quality_item_fmt_q_0() { fn test_quality_item_fmt_q_0() {
// Custom value // Custom value
let x = QualityItem{ let x = QualityItem {
item: EncodingExt("identity".to_owned()), item: EncodingExt("identity".to_owned()),
quality: Quality(0), quality: Quality(0),
}; };
@ -203,22 +214,46 @@ mod tests {
#[test] #[test]
fn test_quality_item_from_str1() { fn test_quality_item_from_str1() {
let x: Result<QualityItem<Encoding>, _> = "chunked".parse(); let x: Result<QualityItem<Encoding>, _> = "chunked".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), }); assert_eq!(
x.unwrap(),
QualityItem {
item: Chunked,
quality: Quality(1000),
}
);
} }
#[test] #[test]
fn test_quality_item_from_str2() { fn test_quality_item_from_str2() {
let x: Result<QualityItem<Encoding>, _> = "chunked; q=1".parse(); let x: Result<QualityItem<Encoding>, _> = "chunked; q=1".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), }); assert_eq!(
x.unwrap(),
QualityItem {
item: Chunked,
quality: Quality(1000),
}
);
} }
#[test] #[test]
fn test_quality_item_from_str3() { fn test_quality_item_from_str3() {
let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.5".parse(); let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.5".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(500), }); assert_eq!(
x.unwrap(),
QualityItem {
item: Gzip,
quality: Quality(500),
}
);
} }
#[test] #[test]
fn test_quality_item_from_str4() { fn test_quality_item_from_str4() {
let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.273".parse(); let x: Result<QualityItem<Encoding>, _> = "gzip; q=0.273".parse();
assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(273), }); assert_eq!(
x.unwrap(),
QualityItem {
item: Gzip,
quality: Quality(273),
}
);
} }
#[test] #[test]
fn test_quality_item_from_str5() { fn test_quality_item_from_str5() {
@ -245,14 +280,14 @@ mod tests {
#[test] #[test]
#[should_panic] // FIXME - 32-bit msvc unwinding broken #[should_panic] // FIXME - 32-bit msvc unwinding broken
#[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)] #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)]
fn test_quality_invalid() { fn test_quality_invalid() {
q(-1.0); q(-1.0);
} }
#[test] #[test]
#[should_panic] // FIXME - 32-bit msvc unwinding broken #[should_panic] // FIXME - 32-bit msvc unwinding broken
#[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)] #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)]
fn test_quality_invalid2() { fn test_quality_invalid2() {
q(2.0); q(2.0);
} }
@ -260,6 +295,10 @@ mod tests {
#[test] #[test]
fn test_fuzzing_bugs() { fn test_fuzzing_bugs() {
assert!("99999;".parse::<QualityItem<String>>().is_err()); assert!("99999;".parse::<QualityItem<String>>().is_err());
assert!("\x0d;;;=\u{d6aa}==".parse::<QualityItem<String>>().is_err()) assert!(
"\x0d;;;=\u{d6aa}=="
.parse::<QualityItem<String>>()
.is_err()
)
} }
} }

View File

@ -1,7 +1,7 @@
//! Various helpers //! Various helpers
use regex::Regex;
use http::{header, StatusCode}; use http::{header, StatusCode};
use regex::Regex;
use handler::Handler; use handler::Handler;
use httprequest::HttpRequest; use httprequest::HttpRequest;
@ -24,9 +24,11 @@ use httpresponse::HttpResponse;
/// defined with trailing slash and the request comes without it, it will /// defined with trailing slash and the request comes without it, it will
/// append it automatically. /// append it automatically.
/// ///
/// If *merge* is *true*, merge multiple consecutive slashes in the path into one. /// If *merge* is *true*, merge multiple consecutive slashes in the path into
/// one.
/// ///
/// This handler designed to be use as a handler for application's *default resource*. /// This handler designed to be use as a handler for application's *default
/// resource*.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -55,7 +57,8 @@ pub struct NormalizePath {
impl Default for NormalizePath { impl Default for NormalizePath {
/// Create default `NormalizePath` instance, *append* is set to *true*, /// Create default `NormalizePath` instance, *append* is set to *true*,
/// *merge* is set to *true* and *redirect* is set to `StatusCode::MOVED_PERMANENTLY` /// *merge* is set to *true* and *redirect* is set to
/// `StatusCode::MOVED_PERMANENTLY`
fn default() -> NormalizePath { fn default() -> NormalizePath {
NormalizePath { NormalizePath {
append: true, append: true,
@ -91,7 +94,11 @@ impl<S> Handler<S> for NormalizePath {
let p = self.re_merge.replace_all(req.path(), "/"); let p = self.re_merge.replace_all(req.path(), "/");
if p.len() != req.path().len() { if p.len() != req.path().len() {
if router.has_route(p.as_ref()) { if router.has_route(p.as_ref()) {
let p = if !query.is_empty() { p + "?" + query } else { p }; let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect) return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_ref()) .header(header::LOCATION, p.as_ref())
.finish(); .finish();
@ -100,10 +107,14 @@ impl<S> Handler<S> for NormalizePath {
if self.append && !p.ends_with('/') { if self.append && !p.ends_with('/') {
let p = p.as_ref().to_owned() + "/"; let p = p.as_ref().to_owned() + "/";
if router.has_route(&p) { if router.has_route(&p) {
let p = if !query.is_empty() { p + "?" + query } else { p }; let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect) return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_str()) .header(header::LOCATION, p.as_str())
.finish() .finish();
} }
} }
@ -113,11 +124,13 @@ impl<S> Handler<S> for NormalizePath {
if router.has_route(p) { if router.has_route(p) {
let mut req = HttpResponse::build(self.redirect); let mut req = HttpResponse::build(self.redirect);
return if !query.is_empty() { return if !query.is_empty() {
req.header(header::LOCATION, (p.to_owned() + "?" + query).as_str()) req.header(
header::LOCATION,
(p.to_owned() + "?" + query).as_str(),
)
} else { } else {
req.header(header::LOCATION, p) req.header(header::LOCATION, p)
} }.finish();
.finish();
} }
} }
} else if p.ends_with('/') { } else if p.ends_with('/') {
@ -126,12 +139,13 @@ impl<S> Handler<S> for NormalizePath {
if router.has_route(p) { if router.has_route(p) {
let mut req = HttpResponse::build(self.redirect); let mut req = HttpResponse::build(self.redirect);
return if !query.is_empty() { return if !query.is_empty() {
req.header(header::LOCATION, req.header(
(p.to_owned() + "?" + query).as_str()) header::LOCATION,
(p.to_owned() + "?" + query).as_str(),
)
} else { } else {
req.header(header::LOCATION, p) req.header(header::LOCATION, p)
} }.finish();
.finish();
} }
} }
} }
@ -139,7 +153,11 @@ impl<S> Handler<S> for NormalizePath {
if self.append && !req.path().ends_with('/') { if self.append && !req.path().ends_with('/') {
let p = req.path().to_owned() + "/"; let p = req.path().to_owned() + "/";
if router.has_route(&p) { if router.has_route(&p) {
let p = if !query.is_empty() { p + "?" + query } else { p }; let p = if !query.is_empty() {
p + "?" + query
} else {
p
};
return HttpResponse::build(self.redirect) return HttpResponse::build(self.redirect)
.header(header::LOCATION, p.as_str()) .header(header::LOCATION, p.as_str())
.finish(); .finish();
@ -153,9 +171,9 @@ impl<S> Handler<S> for NormalizePath {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use application::App;
use http::{header, Method}; use http::{header, Method};
use test::TestRequest; use test::TestRequest;
use application::App;
fn index(_req: HttpRequest) -> HttpResponse { fn index(_req: HttpRequest) -> HttpResponse {
HttpResponse::new(StatusCode::OK) HttpResponse::new(StatusCode::OK)
@ -170,17 +188,32 @@ mod tests {
.finish(); .finish();
// trailing slashes // trailing slashes
let params = let params = vec![
vec![("/resource1", "", StatusCode::OK), ("/resource1", "", StatusCode::OK),
("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY), (
("/resource2", "/resource2/", StatusCode::MOVED_PERMANENTLY), "/resource1/",
("/resource2/", "", StatusCode::OK), "/resource1",
("/resource1?p1=1&p2=2", "", StatusCode::OK), StatusCode::MOVED_PERMANENTLY,
("/resource1/?p1=1&p2=2", "/resource1?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY), ),
("/resource2?p1=1&p2=2", "/resource2/?p1=1&p2=2", (
StatusCode::MOVED_PERMANENTLY), "/resource2",
("/resource2/?p1=1&p2=2", "", StatusCode::OK) "/resource2/",
]; StatusCode::MOVED_PERMANENTLY,
),
("/resource2/", "", StatusCode::OK),
("/resource1?p1=1&p2=2", "", StatusCode::OK),
(
"/resource1/?p1=1&p2=2",
"/resource1?p1=1&p2=2",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2?p1=1&p2=2",
"/resource2/?p1=1&p2=2",
StatusCode::MOVED_PERMANENTLY,
),
("/resource2/?p1=1&p2=2", "", StatusCode::OK),
];
for (path, target, code) in params { for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish()); let req = app.prepare_request(TestRequest::with_uri(path).finish());
let resp = app.run(req); let resp = app.run(req);
@ -189,7 +222,12 @@ mod tests {
if !target.is_empty() { if !target.is_empty() {
assert_eq!( assert_eq!(
target, target,
r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
} }
} }
} }
@ -199,19 +237,25 @@ mod tests {
let mut app = App::new() let mut app = App::new()
.resource("/resource1", |r| r.method(Method::GET).f(index)) .resource("/resource1", |r| r.method(Method::GET).f(index))
.resource("/resource2/", |r| r.method(Method::GET).f(index)) .resource("/resource2/", |r| r.method(Method::GET).f(index))
.default_resource(|r| r.h( .default_resource(|r| {
NormalizePath::new(false, true, StatusCode::MOVED_PERMANENTLY))) r.h(NormalizePath::new(
false,
true,
StatusCode::MOVED_PERMANENTLY,
))
})
.finish(); .finish();
// trailing slashes // trailing slashes
let params = vec![("/resource1", StatusCode::OK), let params = vec![
("/resource1/", StatusCode::MOVED_PERMANENTLY), ("/resource1", StatusCode::OK),
("/resource2", StatusCode::NOT_FOUND), ("/resource1/", StatusCode::MOVED_PERMANENTLY),
("/resource2/", StatusCode::OK), ("/resource2", StatusCode::NOT_FOUND),
("/resource1?p1=1&p2=2", StatusCode::OK), ("/resource2/", StatusCode::OK),
("/resource1/?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY), ("/resource1?p1=1&p2=2", StatusCode::OK),
("/resource2?p1=1&p2=2", StatusCode::NOT_FOUND), ("/resource1/?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY),
("/resource2/?p1=1&p2=2", StatusCode::OK) ("/resource2?p1=1&p2=2", StatusCode::NOT_FOUND),
("/resource2/?p1=1&p2=2", StatusCode::OK),
]; ];
for (path, code) in params { for (path, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish()); let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -232,21 +276,77 @@ mod tests {
// trailing slashes // trailing slashes
let params = vec![ let params = vec![
("/resource1/a/b", "", StatusCode::OK), ("/resource1/a/b", "", StatusCode::OK),
("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY), (
("/resource1//", "/resource1", StatusCode::MOVED_PERMANENTLY), "/resource1/",
("//resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), "/resource1",
("//resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("//resource1//a//b//", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), ),
("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), (
("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), "/resource1//",
("/////resource1/a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), "/resource1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource1//a//b//",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
("/resource1/a/b?p=1", "", StatusCode::OK), ("/resource1/a/b?p=1", "", StatusCode::OK),
("//resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), (
("//resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), "//resource1//a//b?p=1",
("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), "/resource1/a/b?p=1",
("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("/////resource1/a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), ),
("/////resource1/a//b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), (
"//resource1//a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a//b//?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
]; ];
for (path, target, code) in params { for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish()); let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -256,7 +356,12 @@ mod tests {
if !target.is_empty() { if !target.is_empty() {
assert_eq!( assert_eq!(
target, target,
r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
} }
} }
} }
@ -274,38 +379,158 @@ mod tests {
// trailing slashes // trailing slashes
let params = vec![ let params = vec![
("/resource1/a/b", "", StatusCode::OK), ("/resource1/a/b", "", StatusCode::OK),
("/resource1/a/b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), (
("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), "/resource1/a/b/",
("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), "/resource1/a/b",
("//resource2//a//b//", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), ),
("///resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), (
("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), "//resource2//a//b",
("/////resource1/a///b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), "/resource2/a/b/",
("/resource2/a/b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b//",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b/",
"/resource1/a/b",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2/a/b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
("/resource2/a/b/", "", StatusCode::OK), ("/resource2/a/b/", "", StatusCode::OK),
("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), (
("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), "//resource2//a//b",
("///resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), "/resource2/a/b/",
("///resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("/////resource2/a///b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), ),
("/////resource2/a///b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), (
"//resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b/",
"/resource2/a/b/",
StatusCode::MOVED_PERMANENTLY,
),
("/resource1/a/b?p=1", "", StatusCode::OK), ("/resource1/a/b?p=1", "", StatusCode::OK),
("/resource1/a/b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), (
("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), "/resource1/a/b/?p=1",
("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), "/resource1/a/b?p=1",
("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("///resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), ),
("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), (
("/////resource1/a///b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), "//resource2//a//b?p=1",
("/////resource1/a///b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), "/resource2/a/b/?p=1",
("/resource2/a/b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), ),
("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), (
("///resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), "//resource2//a//b/?p=1",
("///resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), "/resource2/a/b/?p=1",
("/////resource2/a///b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), StatusCode::MOVED_PERMANENTLY,
("/////resource2/a///b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), ),
(
"///resource1//a//b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource1//a//b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b/?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource1/a///b//?p=1",
"/resource1/a/b?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/resource2/a/b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"//resource2//a//b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"///resource2//a//b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
(
"/////resource2/a///b/?p=1",
"/resource2/a/b/?p=1",
StatusCode::MOVED_PERMANENTLY,
),
]; ];
for (path, target, code) in params { for (path, target, code) in params {
let req = app.prepare_request(TestRequest::with_uri(path).finish()); let req = app.prepare_request(TestRequest::with_uri(path).finish());
@ -314,7 +539,13 @@ mod tests {
assert_eq!(r.status(), code); assert_eq!(r.status(), code);
if !target.is_empty() { if !target.is_empty() {
assert_eq!( assert_eq!(
target, r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); target,
r.headers()
.get(header::LOCATION)
.unwrap()
.to_str()
.unwrap()
);
} }
} }
} }

View File

@ -4,127 +4,155 @@ use http::StatusCode;
use body::Body; use body::Body;
use error::Error; use error::Error;
use handler::{Reply, Handler, RouteHandler, Responder}; use handler::{Handler, Reply, Responder, RouteHandler};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::{HttpResponse, HttpResponseBuilder}; use httpresponse::{HttpResponse, HttpResponseBuilder};
#[deprecated(since="0.5.0", note="please use `HttpResponse::Ok()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::Ok()` instead")]
pub const HttpOk: StaticResponse = StaticResponse(StatusCode::OK); pub const HttpOk: StaticResponse = StaticResponse(StatusCode::OK);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Created()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::Created()` instead")]
pub const HttpCreated: StaticResponse = StaticResponse(StatusCode::CREATED); pub const HttpCreated: StaticResponse = StaticResponse(StatusCode::CREATED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Accepted()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::Accepted()` instead")]
pub const HttpAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED); pub const HttpAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::pNonAuthoritativeInformation()` instead")] note = "please use `HttpResponse::pNonAuthoritativeInformation()` instead")]
pub const HttpNonAuthoritativeInformation: StaticResponse = pub const HttpNonAuthoritativeInformation: StaticResponse =
StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION); StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NoContent()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::NoContent()` instead")]
pub const HttpNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT); pub const HttpNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ResetContent()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ResetContent()` instead")]
pub const HttpResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT); pub const HttpResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PartialContent()` instead")] #[deprecated(since = "0.5.0",
pub const HttpPartialContent: StaticResponse = StaticResponse(StatusCode::PARTIAL_CONTENT); note = "please use `HttpResponse::PartialContent()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::MultiStatus()` instead")] pub const HttpPartialContent: StaticResponse =
StaticResponse(StatusCode::PARTIAL_CONTENT);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::MultiStatus()` instead")]
pub const HttpMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS); pub const HttpMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS);
#[deprecated(since="0.5.0", note="please use `HttpResponse::AlreadyReported()` instead")] #[deprecated(since = "0.5.0",
pub const HttpAlreadyReported: StaticResponse = StaticResponse(StatusCode::ALREADY_REPORTED); note = "please use `HttpResponse::AlreadyReported()` instead")]
pub const HttpAlreadyReported: StaticResponse =
StaticResponse(StatusCode::ALREADY_REPORTED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MultipleChoices()` instead")] #[deprecated(since = "0.5.0",
pub const HttpMultipleChoices: StaticResponse = StaticResponse(StatusCode::MULTIPLE_CHOICES); note = "please use `HttpResponse::MultipleChoices()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::MovedPermanently()` instead")] pub const HttpMultipleChoices: StaticResponse =
pub const HttpMovedPermanently: StaticResponse = StaticResponse(StatusCode::MOVED_PERMANENTLY); StaticResponse(StatusCode::MULTIPLE_CHOICES);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Found()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::MovedPermanently()` instead")]
pub const HttpMovedPermanently: StaticResponse =
StaticResponse(StatusCode::MOVED_PERMANENTLY);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Found()` instead")]
pub const HttpFound: StaticResponse = StaticResponse(StatusCode::FOUND); pub const HttpFound: StaticResponse = StaticResponse(StatusCode::FOUND);
#[deprecated(since="0.5.0", note="please use `HttpResponse::SeeOther()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::SeeOther()` instead")]
pub const HttpSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER); pub const HttpSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotModified()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::NotModified()` instead")]
pub const HttpNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED); pub const HttpNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::UseProxy()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::UseProxy()` instead")]
pub const HttpUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY); pub const HttpUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY);
#[deprecated(since="0.5.0", note="please use `HttpResponse::TemporaryRedirect()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::TemporaryRedirect()` instead")]
pub const HttpTemporaryRedirect: StaticResponse = pub const HttpTemporaryRedirect: StaticResponse =
StaticResponse(StatusCode::TEMPORARY_REDIRECT); StaticResponse(StatusCode::TEMPORARY_REDIRECT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PermanentRedirect()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PermanentRedirect()` instead")]
pub const HttpPermanentRedirect: StaticResponse = pub const HttpPermanentRedirect: StaticResponse =
StaticResponse(StatusCode::PERMANENT_REDIRECT); StaticResponse(StatusCode::PERMANENT_REDIRECT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::BadRequest()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::BadRequest()` instead")]
pub const HttpBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST); pub const HttpBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Unauthorized()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::Unauthorized()` instead")]
pub const HttpUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED); pub const HttpUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PaymentRequired()` instead")] #[deprecated(since = "0.5.0",
pub const HttpPaymentRequired: StaticResponse = StaticResponse(StatusCode::PAYMENT_REQUIRED); note = "please use `HttpResponse::PaymentRequired()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::Forbidden()` instead")] pub const HttpPaymentRequired: StaticResponse =
StaticResponse(StatusCode::PAYMENT_REQUIRED);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Forbidden()` instead")]
pub const HttpForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN); pub const HttpForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotFound()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::NotFound()` instead")]
pub const HttpNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND); pub const HttpNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND);
#[deprecated(since="0.5.0", note="please use `HttpResponse::MethodNotAllowed()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::MethodNotAllowed()` instead")]
pub const HttpMethodNotAllowed: StaticResponse = pub const HttpMethodNotAllowed: StaticResponse =
StaticResponse(StatusCode::METHOD_NOT_ALLOWED); StaticResponse(StatusCode::METHOD_NOT_ALLOWED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotAcceptable()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::NotAcceptable()` instead")]
pub const HttpNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE); pub const HttpNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::ProxyAuthenticationRequired()` instead")] note = "please use `HttpResponse::ProxyAuthenticationRequired()` instead")]
pub const HttpProxyAuthenticationRequired: StaticResponse = pub const HttpProxyAuthenticationRequired: StaticResponse =
StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED); StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::RequestTimeout()` instead")] #[deprecated(since = "0.5.0",
pub const HttpRequestTimeout: StaticResponse = StaticResponse(StatusCode::REQUEST_TIMEOUT); note = "please use `HttpResponse::RequestTimeout()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::Conflict()` instead")] pub const HttpRequestTimeout: StaticResponse =
StaticResponse(StatusCode::REQUEST_TIMEOUT);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::Conflict()` instead")]
pub const HttpConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT); pub const HttpConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT);
#[deprecated(since="0.5.0", note="please use `HttpResponse::Gone()` instead")] #[deprecated(since = "0.5.0", note = "please use `HttpResponse::Gone()` instead")]
pub const HttpGone: StaticResponse = StaticResponse(StatusCode::GONE); pub const HttpGone: StaticResponse = StaticResponse(StatusCode::GONE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::LengthRequired()` instead")] #[deprecated(since = "0.5.0",
pub const HttpLengthRequired: StaticResponse = StaticResponse(StatusCode::LENGTH_REQUIRED); note = "please use `HttpResponse::LengthRequired()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::PreconditionFailed()` instead")] pub const HttpLengthRequired: StaticResponse =
StaticResponse(StatusCode::LENGTH_REQUIRED);
#[deprecated(since = "0.5.0",
note = "please use `HttpResponse::PreconditionFailed()` instead")]
pub const HttpPreconditionFailed: StaticResponse = pub const HttpPreconditionFailed: StaticResponse =
StaticResponse(StatusCode::PRECONDITION_FAILED); StaticResponse(StatusCode::PRECONDITION_FAILED);
#[deprecated(since="0.5.0", note="please use `HttpResponse::PayloadTooLarge()` instead")] #[deprecated(since = "0.5.0",
pub const HttpPayloadTooLarge: StaticResponse = StaticResponse(StatusCode::PAYLOAD_TOO_LARGE); note = "please use `HttpResponse::PayloadTooLarge()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::UriTooLong()` instead")] pub const HttpPayloadTooLarge: StaticResponse =
StaticResponse(StatusCode::PAYLOAD_TOO_LARGE);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::UriTooLong()` instead")]
pub const HttpUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG); pub const HttpUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::UnsupportedMediaType()` instead")] note = "please use `HttpResponse::UnsupportedMediaType()` instead")]
pub const HttpUnsupportedMediaType: StaticResponse = pub const HttpUnsupportedMediaType: StaticResponse =
StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE); StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::RangeNotSatisfiable()` instead")] note = "please use `HttpResponse::RangeNotSatisfiable()` instead")]
pub const HttpRangeNotSatisfiable: StaticResponse = pub const HttpRangeNotSatisfiable: StaticResponse =
StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE); StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ExpectationFailed()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ExpectationFailed()` instead")]
pub const HttpExpectationFailed: StaticResponse = pub const HttpExpectationFailed: StaticResponse =
StaticResponse(StatusCode::EXPECTATION_FAILED); StaticResponse(StatusCode::EXPECTATION_FAILED);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::InternalServerError()` instead")] note = "please use `HttpResponse::InternalServerError()` instead")]
pub const HttpInternalServerError: StaticResponse = pub const HttpInternalServerError: StaticResponse =
StaticResponse(StatusCode::INTERNAL_SERVER_ERROR); StaticResponse(StatusCode::INTERNAL_SERVER_ERROR);
#[deprecated(since="0.5.0", note="please use `HttpResponse::NotImplemented()` instead")] #[deprecated(since = "0.5.0",
pub const HttpNotImplemented: StaticResponse = StaticResponse(StatusCode::NOT_IMPLEMENTED); note = "please use `HttpResponse::NotImplemented()` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse::BadGateway()` instead")] pub const HttpNotImplemented: StaticResponse =
StaticResponse(StatusCode::NOT_IMPLEMENTED);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse::BadGateway()` instead")]
pub const HttpBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY); pub const HttpBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY);
#[deprecated(since="0.5.0", note="please use `HttpResponse::ServiceUnavailable()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::ServiceUnavailable()` instead")]
pub const HttpServiceUnavailable: StaticResponse = pub const HttpServiceUnavailable: StaticResponse =
StaticResponse(StatusCode::SERVICE_UNAVAILABLE); StaticResponse(StatusCode::SERVICE_UNAVAILABLE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::GatewayTimeout()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::GatewayTimeout()` instead")]
pub const HttpGatewayTimeout: StaticResponse = pub const HttpGatewayTimeout: StaticResponse =
StaticResponse(StatusCode::GATEWAY_TIMEOUT); StaticResponse(StatusCode::GATEWAY_TIMEOUT);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::VersionNotSupported()` instead")] note = "please use `HttpResponse::VersionNotSupported()` instead")]
pub const HttpVersionNotSupported: StaticResponse = pub const HttpVersionNotSupported: StaticResponse =
StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED); StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::VariantAlsoNegotiates()` instead")] note = "please use `HttpResponse::VariantAlsoNegotiates()` instead")]
pub const HttpVariantAlsoNegotiates: StaticResponse = pub const HttpVariantAlsoNegotiates: StaticResponse =
StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES); StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES);
#[deprecated(since="0.5.0", #[deprecated(since = "0.5.0",
note="please use `HttpResponse::InsufficientStorage()` instead")] note = "please use `HttpResponse::InsufficientStorage()` instead")]
pub const HttpInsufficientStorage: StaticResponse = pub const HttpInsufficientStorage: StaticResponse =
StaticResponse(StatusCode::INSUFFICIENT_STORAGE); StaticResponse(StatusCode::INSUFFICIENT_STORAGE);
#[deprecated(since="0.5.0", note="please use `HttpResponse::LoopDetected()` instead")] #[deprecated(since = "0.5.0",
note = "please use `HttpResponse::LoopDetected()` instead")]
pub const HttpLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED); pub const HttpLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED);
#[deprecated(since = "0.5.0", note = "please use `HttpResponse` instead")]
#[deprecated(since="0.5.0", note="please use `HttpResponse` instead")]
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub struct StaticResponse(StatusCode); pub struct StaticResponse(StatusCode);
@ -186,14 +214,17 @@ macro_rules! STATIC_RESP {
pub fn $name() -> HttpResponseBuilder { pub fn $name() -> HttpResponseBuilder {
HttpResponse::build($status) HttpResponse::build($status)
} }
} };
} }
impl HttpResponse { impl HttpResponse {
STATIC_RESP!(Ok, StatusCode::OK); STATIC_RESP!(Ok, StatusCode::OK);
STATIC_RESP!(Created, StatusCode::CREATED); STATIC_RESP!(Created, StatusCode::CREATED);
STATIC_RESP!(Accepted, StatusCode::ACCEPTED); STATIC_RESP!(Accepted, StatusCode::ACCEPTED);
STATIC_RESP!(NonAuthoritativeInformation, StatusCode::NON_AUTHORITATIVE_INFORMATION); STATIC_RESP!(
NonAuthoritativeInformation,
StatusCode::NON_AUTHORITATIVE_INFORMATION
);
STATIC_RESP!(NoContent, StatusCode::NO_CONTENT); STATIC_RESP!(NoContent, StatusCode::NO_CONTENT);
STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT); STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT);
@ -218,7 +249,10 @@ impl HttpResponse {
STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN); STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN);
STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED); STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED);
STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE); STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE);
STATIC_RESP!(ProxyAuthenticationRequired, StatusCode::PROXY_AUTHENTICATION_REQUIRED); STATIC_RESP!(
ProxyAuthenticationRequired,
StatusCode::PROXY_AUTHENTICATION_REQUIRED
);
STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT); STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT);
STATIC_RESP!(Conflict, StatusCode::CONFLICT); STATIC_RESP!(Conflict, StatusCode::CONFLICT);
STATIC_RESP!(Gone, StatusCode::GONE); STATIC_RESP!(Gone, StatusCode::GONE);
@ -226,7 +260,10 @@ impl HttpResponse {
STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED); STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED);
STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE); STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE);
STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG); STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG);
STATIC_RESP!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE); STATIC_RESP!(
UnsupportedMediaType,
StatusCode::UNSUPPORTED_MEDIA_TYPE
);
STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE); STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE);
STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED);
@ -235,16 +272,22 @@ impl HttpResponse {
STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY); STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY);
STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE); STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE);
STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT); STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT);
STATIC_RESP!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED); STATIC_RESP!(
STATIC_RESP!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES); VersionNotSupported,
StatusCode::HTTP_VERSION_NOT_SUPPORTED
);
STATIC_RESP!(
VariantAlsoNegotiates,
StatusCode::VARIANT_ALSO_NEGOTIATES
);
STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE); STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE);
STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED); STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED);
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Body, HttpBadRequest, HttpOk, HttpResponse};
use http::StatusCode; use http::StatusCode;
use super::{HttpOk, HttpBadRequest, Body, HttpResponse};
#[test] #[test]
fn test_build() { fn test_build() {

View File

@ -1,45 +1,45 @@
use std::str;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Future, Stream, Poll};
use http_range::HttpRange;
use serde::de::DeserializeOwned;
use mime::Mime;
use serde_urlencoded;
use encoding::all::UTF_8;
use encoding::EncodingRef; use encoding::EncodingRef;
use encoding::types::{Encoding, DecoderTrap}; use encoding::all::UTF_8;
use encoding::label::encoding_from_whatwg_label; use encoding::label::encoding_from_whatwg_label;
use encoding::types::{DecoderTrap, Encoding};
use futures::{Future, Poll, Stream};
use http::{header, HeaderMap}; use http::{header, HeaderMap};
use http_range::HttpRange;
use mime::Mime;
use serde::de::DeserializeOwned;
use serde_urlencoded;
use std::str;
use json::JsonBody; use error::{ContentTypeError, HttpRangeError, ParseError, PayloadError, UrlencodedError};
use header::Header; use header::Header;
use json::JsonBody;
use multipart::Multipart; use multipart::Multipart;
use error::{ParseError, ContentTypeError,
HttpRangeError, PayloadError, UrlencodedError};
/// Trait that implements general purpose operations on http messages /// Trait that implements general purpose operations on http messages
pub trait HttpMessage { pub trait HttpMessage {
/// Read the message headers. /// Read the message headers.
fn headers(&self) -> &HeaderMap; fn headers(&self) -> &HeaderMap;
#[doc(hidden)] #[doc(hidden)]
/// Get a header /// Get a header
fn get_header<H: Header>(&self) -> Option<H> where Self: Sized { fn get_header<H: Header>(&self) -> Option<H>
where
Self: Sized,
{
if self.headers().contains_key(H::name()) { if self.headers().contains_key(H::name()) {
H::parse(self).ok() H::parse(self).ok()
} else { } else {
None None
} }
} }
/// Read the request content type. If request does not contain /// Read the request content type. If request does not contain
/// *Content-Type* header, empty str get returned. /// *Content-Type* header, empty str get returned.
fn content_type(&self) -> &str { fn content_type(&self) -> &str {
if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) {
if let Ok(content_type) = content_type.to_str() { if let Ok(content_type) = content_type.to_str() {
return content_type.split(';').next().unwrap().trim() return content_type.split(';').next().unwrap().trim();
} }
} }
"" ""
@ -73,7 +73,7 @@ pub trait HttpMessage {
Err(_) => Err(ContentTypeError::ParseError), Err(_) => Err(ContentTypeError::ParseError),
}; };
} else { } else {
return Err(ContentTypeError::ParseError) return Err(ContentTypeError::ParseError);
} }
} }
Ok(None) Ok(None)
@ -96,8 +96,10 @@ pub trait HttpMessage {
/// `size` is full size of response (file). /// `size` is full size of response (file).
fn range(&self, size: u64) -> Result<Vec<HttpRange>, HttpRangeError> { fn range(&self, size: u64) -> Result<Vec<HttpRange>, HttpRangeError> {
if let Some(range) = self.headers().get(header::RANGE) { if let Some(range) = self.headers().get(header::RANGE) {
HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) HttpRange::parse(
.map_err(|e| e.into()) unsafe { str::from_utf8_unchecked(range.as_bytes()) },
size,
).map_err(|e| e.into())
} else { } else {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -105,8 +107,9 @@ pub trait HttpMessage {
/// Load http message body. /// Load http message body.
/// ///
/// By default only 256Kb payload reads to a memory, then `PayloadError::Overflow` /// By default only 256Kb payload reads to a memory, then
/// get returned. Use `MessageBody::limit()` method to change upper limit. /// `PayloadError::Overflow` get returned. Use `MessageBody::limit()`
/// method to change upper limit.
/// ///
/// ## Server example /// ## Server example
/// ///
@ -131,14 +134,15 @@ pub trait HttpMessage {
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn body(self) -> MessageBody<Self> fn body(self) -> MessageBody<Self>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{ {
MessageBody::new(self) MessageBody::new(self)
} }
/// Parse `application/x-www-form-urlencoded` encoded request's body. /// Parse `application/x-www-form-urlencoded` encoded request's body.
/// Return `UrlEncoded` future. Form can be deserialized to any type that implements /// Return `UrlEncoded` future. Form can be deserialized to any type that
/// `Deserialize` trait from *serde*. /// implements `Deserialize` trait from *serde*.
/// ///
/// Returns error: /// Returns error:
/// ///
@ -167,7 +171,8 @@ pub trait HttpMessage {
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn urlencoded<T: DeserializeOwned>(self) -> UrlEncoded<Self, T> fn urlencoded<T: DeserializeOwned>(self) -> UrlEncoded<Self, T>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{ {
UrlEncoded::new(self) UrlEncoded::new(self)
} }
@ -205,7 +210,8 @@ pub trait HttpMessage {
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn json<T: DeserializeOwned>(self) -> JsonBody<Self, T> fn json<T: DeserializeOwned>(self) -> JsonBody<Self, T>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{ {
JsonBody::new(self) JsonBody::new(self)
} }
@ -247,7 +253,8 @@ pub trait HttpMessage {
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn multipart(self) -> Multipart<Self> fn multipart(self) -> Multipart<Self>
where Self: Stream<Item=Bytes, Error=PayloadError> + Sized where
Self: Stream<Item = Bytes, Error = PayloadError> + Sized,
{ {
let boundary = Multipart::boundary(self.headers()); let boundary = Multipart::boundary(self.headers());
Multipart::new(boundary, self) Multipart::new(boundary, self)
@ -258,11 +265,10 @@ pub trait HttpMessage {
pub struct MessageBody<T> { pub struct MessageBody<T> {
limit: usize, limit: usize,
req: Option<T>, req: Option<T>,
fut: Option<Box<Future<Item=Bytes, Error=PayloadError>>>, fut: Option<Box<Future<Item = Bytes, Error = PayloadError>>>,
} }
impl<T> MessageBody<T> { impl<T> MessageBody<T> {
/// Create `RequestBody` for request. /// Create `RequestBody` for request.
pub fn new(req: T) -> MessageBody<T> { pub fn new(req: T) -> MessageBody<T> {
MessageBody { MessageBody {
@ -280,7 +286,8 @@ impl<T> MessageBody<T> {
} }
impl<T> Future for MessageBody<T> impl<T> Future for MessageBody<T>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static where
T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
type Item = Bytes; type Item = Bytes;
type Error = PayloadError; type Error = PayloadError;
@ -313,11 +320,14 @@ impl<T> Future for MessageBody<T>
Ok(body) Ok(body)
} }
}) })
.map(|body| body.freeze()) .map(|body| body.freeze()),
)); ));
} }
self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() self.fut
.as_mut()
.expect("UrlEncoded could not be used second time")
.poll()
} }
} }
@ -325,7 +335,7 @@ impl<T> Future for MessageBody<T>
pub struct UrlEncoded<T, U> { pub struct UrlEncoded<T, U> {
req: Option<T>, req: Option<T>,
limit: usize, limit: usize,
fut: Option<Box<Future<Item=U, Error=UrlencodedError>>>, fut: Option<Box<Future<Item = U, Error = UrlencodedError>>>,
} }
impl<T, U> UrlEncoded<T, U> { impl<T, U> UrlEncoded<T, U> {
@ -345,8 +355,9 @@ impl<T, U> UrlEncoded<T, U> {
} }
impl<T, U> Future for UrlEncoded<T, U> impl<T, U> Future for UrlEncoded<T, U>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static, where
U: DeserializeOwned + 'static T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
U: DeserializeOwned + 'static,
{ {
type Item = U; type Item = U;
type Error = UrlencodedError; type Error = UrlencodedError;
@ -354,7 +365,7 @@ impl<T, U> Future for UrlEncoded<T, U>
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(req) = self.req.take() { if let Some(req) = self.req.take() {
if req.chunked().unwrap_or(false) { if req.chunked().unwrap_or(false) {
return Err(UrlencodedError::Chunked) return Err(UrlencodedError::Chunked);
} else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { } else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) {
if let Ok(s) = len.to_str() { if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() { if let Ok(len) = s.parse::<u64>() {
@ -362,18 +373,19 @@ impl<T, U> Future for UrlEncoded<T, U>
return Err(UrlencodedError::Overflow); return Err(UrlencodedError::Overflow);
} }
} else { } else {
return Err(UrlencodedError::UnknownLength) return Err(UrlencodedError::UnknownLength);
} }
} else { } else {
return Err(UrlencodedError::UnknownLength) return Err(UrlencodedError::UnknownLength);
} }
} }
// check content type // check content type
if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" {
return Err(UrlencodedError::ContentType) return Err(UrlencodedError::ContentType);
} }
let encoding = req.encoding().map_err(|_| UrlencodedError::ContentType)?; let encoding = req.encoding()
.map_err(|_| UrlencodedError::ContentType)?;
// future // future
let limit = self.limit; let limit = self.limit;
@ -392,7 +404,8 @@ impl<T, U> Future for UrlEncoded<T, U>
serde_urlencoded::from_bytes::<U>(&body) serde_urlencoded::from_bytes::<U>(&body)
.map_err(|_| UrlencodedError::Parse) .map_err(|_| UrlencodedError::Parse)
} else { } else {
let body = encoding.decode(&body, DecoderTrap::Strict) let body = encoding
.decode(&body, DecoderTrap::Strict)
.map_err(|_| UrlencodedError::Parse)?; .map_err(|_| UrlencodedError::Parse)?;
serde_urlencoded::from_str::<U>(&body) serde_urlencoded::from_str::<U>(&body)
.map_err(|_| UrlencodedError::Parse) .map_err(|_| UrlencodedError::Parse)
@ -401,19 +414,22 @@ impl<T, U> Future for UrlEncoded<T, U>
self.fut = Some(Box::new(fut)); self.fut = Some(Box::new(fut));
} }
self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() self.fut
.as_mut()
.expect("UrlEncoded could not be used second time")
.poll()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use mime;
use encoding::Encoding; use encoding::Encoding;
use encoding::all::ISO_8859_2; use encoding::all::ISO_8859_2;
use futures::Async; use futures::Async;
use http::{Method, Version, Uri}; use http::{Method, Uri, Version};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use mime;
use std::str::FromStr; use std::str::FromStr;
use test::TestRequest; use test::TestRequest;
@ -421,8 +437,9 @@ mod tests {
fn test_content_type() { fn test_content_type() {
let req = TestRequest::with_header("content-type", "text/plain").finish(); let req = TestRequest::with_header("content-type", "text/plain").finish();
assert_eq!(req.content_type(), "text/plain"); assert_eq!(req.content_type(), "text/plain");
let req = TestRequest::with_header( let req =
"content-type", "application/json; charset=utf=8").finish(); TestRequest::with_header("content-type", "application/json; charset=utf=8")
.finish();
assert_eq!(req.content_type(), "application/json"); assert_eq!(req.content_type(), "application/json");
let req = HttpRequest::default(); let req = HttpRequest::default();
assert_eq!(req.content_type(), ""); assert_eq!(req.content_type(), "");
@ -434,8 +451,9 @@ mod tests {
assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON));
let req = HttpRequest::default(); let req = HttpRequest::default();
assert_eq!(req.mime_type().unwrap(), None); assert_eq!(req.mime_type().unwrap(), None);
let req = TestRequest::with_header( let req =
"content-type", "application/json; charset=utf-8").finish(); TestRequest::with_header("content-type", "application/json; charset=utf-8")
.finish();
let mt = req.mime_type().unwrap().unwrap(); let mt = req.mime_type().unwrap().unwrap();
assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8));
assert_eq!(mt.type_(), mime::APPLICATION); assert_eq!(mt.type_(), mime::APPLICATION);
@ -445,7 +463,9 @@ mod tests {
#[test] #[test]
fn test_mime_type_error() { fn test_mime_type_error() {
let req = TestRequest::with_header( let req = TestRequest::with_header(
"content-type", "applicationadfadsfasdflknadsfklnadsfjson").finish(); "content-type",
"applicationadfadsfasdflknadsfklnadsfjson",
).finish();
assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); assert_eq!(Err(ContentTypeError::ParseError), req.mime_type());
} }
@ -454,24 +474,32 @@ mod tests {
let req = HttpRequest::default(); let req = HttpRequest::default();
assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
let req = TestRequest::with_header( let req = TestRequest::with_header("content-type", "application/json").finish();
"content-type", "application/json").finish();
assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); assert_eq!(UTF_8.name(), req.encoding().unwrap().name());
let req = TestRequest::with_header( let req = TestRequest::with_header(
"content-type", "application/json; charset=ISO-8859-2").finish(); "content-type",
"application/json; charset=ISO-8859-2",
).finish();
assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name()); assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name());
} }
#[test] #[test]
fn test_encoding_error() { fn test_encoding_error() {
let req = TestRequest::with_header( let req = TestRequest::with_header("content-type", "applicatjson").finish();
"content-type", "applicatjson").finish(); assert_eq!(
assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); Some(ContentTypeError::ParseError),
req.encoding().err()
);
let req = TestRequest::with_header( let req = TestRequest::with_header(
"content-type", "application/json; charset=kkkttktk").finish(); "content-type",
assert_eq!(Some(ContentTypeError::UnknownEncoding), req.encoding().err()); "application/json; charset=kkkttktk",
).finish();
assert_eq!(
Some(ContentTypeError::UnknownEncoding),
req.encoding().err()
);
} }
#[test] #[test]
@ -495,17 +523,26 @@ mod tests {
let req = HttpRequest::default(); let req = HttpRequest::default();
assert!(!req.chunked().unwrap()); assert!(!req.chunked().unwrap());
let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); let req =
TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
assert!(req.chunked().unwrap()); assert!(req.chunked().unwrap());
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
let s = unsafe{str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())}; let s = unsafe {
str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())
};
headers.insert(header::TRANSFER_ENCODING, headers.insert(
header::HeaderValue::from_str(s).unwrap()); header::TRANSFER_ENCODING,
header::HeaderValue::from_str(s).unwrap(),
);
let req = HttpRequest::new( let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Method::GET,
Version::HTTP_11, headers, None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert!(req.chunked().is_err()); assert!(req.chunked().is_err());
} }
@ -540,51 +577,75 @@ mod tests {
#[test] #[test]
fn test_urlencoded_error() { fn test_urlencoded_error() {
let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); let req =
assert_eq!(req.urlencoded::<Info>() TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish();
.poll().err().unwrap(), UrlencodedError::Chunked); assert_eq!(
req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::Chunked
);
let req = TestRequest::with_header( let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded") header::CONTENT_TYPE,
.header(header::CONTENT_LENGTH, "xxxx") "application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "xxxx")
.finish(); .finish();
assert_eq!(req.urlencoded::<Info>() assert_eq!(
.poll().err().unwrap(), UrlencodedError::UnknownLength); req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::UnknownLength
);
let req = TestRequest::with_header( let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded") header::CONTENT_TYPE,
.header(header::CONTENT_LENGTH, "1000000") "application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "1000000")
.finish(); .finish();
assert_eq!(req.urlencoded::<Info>() assert_eq!(
.poll().err().unwrap(), UrlencodedError::Overflow); req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::Overflow
);
let req = TestRequest::with_header( let req = TestRequest::with_header(header::CONTENT_TYPE, "text/plain")
header::CONTENT_TYPE, "text/plain")
.header(header::CONTENT_LENGTH, "10") .header(header::CONTENT_LENGTH, "10")
.finish(); .finish();
assert_eq!(req.urlencoded::<Info>() assert_eq!(
.poll().err().unwrap(), UrlencodedError::ContentType); req.urlencoded::<Info>().poll().err().unwrap(),
UrlencodedError::ContentType
);
} }
#[test] #[test]
fn test_urlencoded() { fn test_urlencoded() {
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded") header::CONTENT_TYPE,
.header(header::CONTENT_LENGTH, "11") "application/x-www-form-urlencoded",
).header(header::CONTENT_LENGTH, "11")
.finish(); .finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let result = req.urlencoded::<Info>().poll().ok().unwrap(); let result = req.urlencoded::<Info>().poll().ok().unwrap();
assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()})); assert_eq!(
result,
Async::Ready(Info {
hello: "world".to_owned()
})
);
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8") header::CONTENT_TYPE,
.header(header::CONTENT_LENGTH, "11") "application/x-www-form-urlencoded; charset=utf-8",
).header(header::CONTENT_LENGTH, "11")
.finish(); .finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut()
.unread_data(Bytes::from_static(b"hello=world"));
let result = req.urlencoded().poll().ok().unwrap(); let result = req.urlencoded().poll().ok().unwrap();
assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()})); assert_eq!(
result,
Async::Ready(Info {
hello: "world".to_owned()
})
);
} }
#[test] #[test]
@ -602,14 +663,16 @@ mod tests {
} }
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.payload_mut().unread_data(Bytes::from_static(b"test")); req.payload_mut()
.unread_data(Bytes::from_static(b"test"));
match req.body().poll().ok().unwrap() { match req.body().poll().ok().unwrap() {
Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.payload_mut().unread_data(Bytes::from_static(b"11111111111111")); req.payload_mut()
.unread_data(Bytes::from_static(b"11111111111111"));
match req.body().limit(5).poll().err().unwrap() { match req.body().limit(5).poll().err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),

View File

@ -1,30 +1,29 @@
//! HTTP Request message related code. //! HTTP Request message related code.
use std::{io, cmp, str, fmt, mem};
use std::rc::Rc;
use std::net::SocketAddr;
use std::borrow::Cow;
use bytes::Bytes; use bytes::Bytes;
use cookie::Cookie; use cookie::Cookie;
use futures::{Async, Stream, Poll};
use futures::future::{FutureResult, result};
use futures_cpupool::CpuPool;
use failure; use failure;
use url::{Url, form_urlencoded}; use futures::future::{result, FutureResult};
use http::{header, Uri, Method, Version, HeaderMap, Extensions, StatusCode}; use futures::{Async, Poll, Stream};
use tokio_io::AsyncRead; use futures_cpupool::CpuPool;
use http::{header, Extensions, HeaderMap, Method, StatusCode, Uri, Version};
use percent_encoding::percent_decode; use percent_encoding::percent_decode;
use std::borrow::Cow;
use std::net::SocketAddr;
use std::rc::Rc;
use std::{cmp, fmt, io, mem, str};
use tokio_io::AsyncRead;
use url::{form_urlencoded, Url};
use body::Body; use body::Body;
use info::ConnectionInfo; use error::{CookieParseError, Error, PayloadError, UrlGenerationError};
use param::Params;
use router::{Router, Resource};
use payload::Payload;
use handler::FromRequest; use handler::FromRequest;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httpresponse::{HttpResponse, HttpResponseBuilder}; use httpresponse::{HttpResponse, HttpResponseBuilder};
use info::ConnectionInfo;
use param::Params;
use payload::Payload;
use router::{Resource, Router};
use server::helpers::SharedHttpInnerMessage; use server::helpers::SharedHttpInnerMessage;
use error::{Error, UrlGenerationError, CookieParseError, PayloadError};
pub struct HttpInnerMessage { pub struct HttpInnerMessage {
pub version: Version, pub version: Version,
@ -42,14 +41,13 @@ pub struct HttpInnerMessage {
resource: RouterResource, resource: RouterResource,
} }
#[derive(Debug, Copy, Clone,PartialEq)] #[derive(Debug, Copy, Clone, PartialEq)]
enum RouterResource { enum RouterResource {
Notset, Notset,
Normal(u16), Normal(u16),
} }
impl Default for HttpInnerMessage { impl Default for HttpInnerMessage {
fn default() -> HttpInnerMessage { fn default() -> HttpInnerMessage {
HttpInnerMessage { HttpInnerMessage {
method: Method::GET, method: Method::GET,
@ -70,7 +68,6 @@ impl Default for HttpInnerMessage {
} }
impl HttpInnerMessage { impl HttpInnerMessage {
/// Checks if a connection should be kept alive. /// Checks if a connection should be kept alive.
#[inline] #[inline]
pub fn keep_alive(&self) -> bool { pub fn keep_alive(&self) -> bool {
@ -79,8 +76,8 @@ impl HttpInnerMessage {
if self.version == Version::HTTP_10 && conn.contains("keep-alive") { if self.version == Version::HTTP_10 && conn.contains("keep-alive") {
true true
} else { } else {
self.version == Version::HTTP_11 && self.version == Version::HTTP_11
!(conn.contains("close") || conn.contains("upgrade")) && !(conn.contains("close") || conn.contains("upgrade"))
} }
} else { } else {
false false
@ -105,21 +102,20 @@ impl HttpInnerMessage {
} }
} }
lazy_static!{ lazy_static! {
static ref RESOURCE: Resource = Resource::unset(); static ref RESOURCE: Resource = Resource::unset();
} }
/// An HTTP Request /// An HTTP Request
pub struct HttpRequest<S=()>(SharedHttpInnerMessage, Option<Rc<S>>, Option<Router>); pub struct HttpRequest<S = ()>(SharedHttpInnerMessage, Option<Rc<S>>, Option<Router>);
impl HttpRequest<()> { impl HttpRequest<()> {
/// Construct a new Request. /// Construct a new Request.
#[inline] #[inline]
pub fn new(method: Method, uri: Uri, pub fn new(
version: Version, headers: HeaderMap, payload: Option<Payload>) method: Method, uri: Uri, version: Version, headers: HeaderMap,
-> HttpRequest payload: Option<Payload>,
{ ) -> HttpRequest {
HttpRequest( HttpRequest(
SharedHttpInnerMessage::from_message(HttpInnerMessage { SharedHttpInnerMessage::from_message(HttpInnerMessage {
method, method,
@ -142,7 +138,7 @@ impl HttpRequest<()> {
} }
#[inline(always)] #[inline(always)]
#[cfg_attr(feature="cargo-clippy", allow(inline_always))] #[cfg_attr(feature = "cargo-clippy", allow(inline_always))]
pub(crate) fn from_message(msg: SharedHttpInnerMessage) -> HttpRequest { pub(crate) fn from_message(msg: SharedHttpInnerMessage) -> HttpRequest {
HttpRequest(msg, None, None) HttpRequest(msg, None, None)
} }
@ -154,7 +150,6 @@ impl HttpRequest<()> {
} }
} }
impl<S> HttpMessage for HttpRequest<S> { impl<S> HttpMessage for HttpRequest<S> {
#[inline] #[inline]
fn headers(&self) -> &HeaderMap { fn headers(&self) -> &HeaderMap {
@ -163,7 +158,6 @@ impl<S> HttpMessage for HttpRequest<S> {
} }
impl<S> HttpRequest<S> { impl<S> HttpRequest<S> {
#[inline] #[inline]
/// Construct new http request with state. /// Construct new http request with state.
pub fn change_state<NS>(&self, state: Rc<NS>) -> HttpRequest<NS> { pub fn change_state<NS>(&self, state: Rc<NS>) -> HttpRequest<NS> {
@ -211,8 +205,10 @@ impl<S> HttpRequest<S> {
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
pub fn cpu_pool(&self) -> &CpuPool { pub fn cpu_pool(&self) -> &CpuPool {
self.router().expect("HttpRequest has to have Router instance") self.router()
.server_settings().cpu_pool() .expect("HttpRequest has to have Router instance")
.server_settings()
.cpu_pool()
} }
/// Create http response /// Create http response
@ -235,12 +231,18 @@ impl<S> HttpRequest<S> {
#[doc(hidden)] #[doc(hidden)]
pub fn prefix_len(&self) -> usize { pub fn prefix_len(&self) -> usize {
if let Some(router) = self.router() { router.prefix().len() } else { 0 } if let Some(router) = self.router() {
router.prefix().len()
} else {
0
}
} }
/// Read the Request Uri. /// Read the Request Uri.
#[inline] #[inline]
pub fn uri(&self) -> &Uri { &self.as_ref().uri } pub fn uri(&self) -> &Uri {
&self.as_ref().uri
}
/// Returns mutable the Request Uri. /// Returns mutable the Request Uri.
/// ///
@ -252,7 +254,9 @@ impl<S> HttpRequest<S> {
/// Read the Request method. /// Read the Request method.
#[inline] #[inline]
pub fn method(&self) -> &Method { &self.as_ref().method } pub fn method(&self) -> &Method {
&self.as_ref().method
}
/// Read the Request Version. /// Read the Request Version.
#[inline] #[inline]
@ -277,14 +281,16 @@ impl<S> HttpRequest<S> {
/// Percent decoded path of this Request. /// Percent decoded path of this Request.
#[inline] #[inline]
pub fn path_decoded(&self) -> Cow<str> { pub fn path_decoded(&self) -> Cow<str> {
percent_decode(self.uri().path().as_bytes()).decode_utf8().unwrap() percent_decode(self.uri().path().as_bytes())
.decode_utf8()
.unwrap()
} }
/// Get *ConnectionInfo* for correct request. /// Get *ConnectionInfo* for correct request.
pub fn connection_info(&self) -> &ConnectionInfo { pub fn connection_info(&self) -> &ConnectionInfo {
if self.as_ref().info.is_none() { if self.as_ref().info.is_none() {
let info: ConnectionInfo<'static> = unsafe{ let info: ConnectionInfo<'static> =
mem::transmute(ConnectionInfo::new(self))}; unsafe { mem::transmute(ConnectionInfo::new(self)) };
self.as_mut().info = Some(info); self.as_mut().info = Some(info);
} }
self.as_ref().info.as_ref().unwrap() self.as_ref().info.as_ref().unwrap()
@ -310,9 +316,12 @@ impl<S> HttpRequest<S> {
/// .finish(); /// .finish();
/// } /// }
/// ``` /// ```
pub fn url_for<U, I>(&self, name: &str, elements: U) -> Result<Url, UrlGenerationError> pub fn url_for<U, I>(
where U: IntoIterator<Item=I>, &self, name: &str, elements: U
I: AsRef<str>, ) -> Result<Url, UrlGenerationError>
where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{ {
if self.router().is_none() { if self.router().is_none() {
Err(UrlGenerationError::RouterNotAvailable) Err(UrlGenerationError::RouterNotAvailable)
@ -320,7 +329,12 @@ impl<S> HttpRequest<S> {
let path = self.router().unwrap().resource_path(name, elements)?; let path = self.router().unwrap().resource_path(name, elements)?;
if path.starts_with('/') { if path.starts_with('/') {
let conn = self.connection_info(); let conn = self.connection_info();
Ok(Url::parse(&format!("{}://{}{}", conn.scheme(), conn.host(), path))?) Ok(Url::parse(&format!(
"{}://{}{}",
conn.scheme(),
conn.host(),
path
))?)
} else { } else {
Ok(Url::parse(&path)?) Ok(Url::parse(&path)?)
} }
@ -338,7 +352,7 @@ impl<S> HttpRequest<S> {
pub fn resource(&self) -> &Resource { pub fn resource(&self) -> &Resource {
if let Some(ref router) = self.2 { if let Some(ref router) = self.2 {
if let RouterResource::Normal(idx) = self.as_ref().resource { if let RouterResource::Normal(idx) = self.as_ref().resource {
return router.get_resource(idx as usize) return router.get_resource(idx as usize);
} }
} }
&*RESOURCE &*RESOURCE
@ -353,7 +367,8 @@ impl<S> HttpRequest<S> {
/// Peer address is actual socket address, if proxy is used in front of /// Peer address is actual socket address, if proxy is used in front of
/// actix http server, then peer address would be address of this proxy. /// actix http server, then peer address would be address of this proxy.
/// ///
/// To get client connection information `connection_info()` method should be used. /// To get client connection information `connection_info()` method should
/// be used.
#[inline] #[inline]
pub fn peer_addr(&self) -> Option<&SocketAddr> { pub fn peer_addr(&self) -> Option<&SocketAddr> {
self.as_ref().addr.as_ref() self.as_ref().addr.as_ref()
@ -368,13 +383,14 @@ impl<S> HttpRequest<S> {
/// Params is a container for url query parameters. /// Params is a container for url query parameters.
pub fn query(&self) -> &Params { pub fn query(&self) -> &Params {
if !self.as_ref().query_loaded { if !self.as_ref().query_loaded {
let params: &mut Params = unsafe{ mem::transmute(&mut self.as_mut().query) }; let params: &mut Params =
unsafe { mem::transmute(&mut self.as_mut().query) };
self.as_mut().query_loaded = true; self.as_mut().query_loaded = true;
for (key, val) in form_urlencoded::parse(self.query_string().as_ref()) { for (key, val) in form_urlencoded::parse(self.query_string().as_ref()) {
params.add(key, val); params.add(key, val);
} }
} }
unsafe{ mem::transmute(&self.as_ref().query) } unsafe { mem::transmute(&self.as_ref().query) }
} }
/// The query string in the URL. /// The query string in the URL.
@ -412,7 +428,7 @@ impl<S> HttpRequest<S> {
if let Ok(cookies) = self.cookies() { if let Ok(cookies) = self.cookies() {
for cookie in cookies { for cookie in cookies {
if cookie.name() == name { if cookie.name() == name {
return Some(cookie) return Some(cookie);
} }
} }
} }
@ -423,16 +439,17 @@ impl<S> HttpRequest<S> {
/// ///
/// Params is a container for url parameters. /// Params is a container for url parameters.
/// Route supports glob patterns: * for a single wildcard segment and :param /// Route supports glob patterns: * for a single wildcard segment and :param
/// for matching storing that segment of the request url in the Params object. /// for matching storing that segment of the request url in the Params
/// object.
#[inline] #[inline]
pub fn match_info(&self) -> &Params { pub fn match_info(&self) -> &Params {
unsafe{ mem::transmute(&self.as_ref().params) } unsafe { mem::transmute(&self.as_ref().params) }
} }
/// Get mutable reference to request's Params. /// Get mutable reference to request's Params.
#[inline] #[inline]
pub fn match_info_mut(&mut self) -> &mut Params { pub fn match_info_mut(&mut self) -> &mut Params {
unsafe{ mem::transmute(&mut self.as_mut().params) } unsafe { mem::transmute(&mut self.as_mut().params) }
} }
/// Checks if a connection should be kept alive. /// Checks if a connection should be kept alive.
@ -444,7 +461,7 @@ impl<S> HttpRequest<S> {
pub(crate) fn upgrade(&self) -> bool { pub(crate) fn upgrade(&self) -> bool {
if let Some(conn) = self.as_ref().headers.get(header::CONNECTION) { if let Some(conn) = self.as_ref().headers.get(header::CONNECTION) {
if let Ok(s) = conn.to_str() { if let Ok(s) = conn.to_str() {
return s.to_lowercase().contains("upgrade") return s.to_lowercase().contains("upgrade");
} }
} }
self.as_ref().method == Method::CONNECT self.as_ref().method == Method::CONNECT
@ -479,7 +496,6 @@ impl<S> HttpRequest<S> {
} }
impl Default for HttpRequest<()> { impl Default for HttpRequest<()> {
/// Construct default request /// Construct default request
fn default() -> HttpRequest { fn default() -> HttpRequest {
HttpRequest(SharedHttpInnerMessage::default(), None, None) HttpRequest(SharedHttpInnerMessage::default(), None, None)
@ -492,8 +508,7 @@ impl<S> Clone for HttpRequest<S> {
} }
} }
impl<S: 'static> FromRequest<S> for HttpRequest<S> impl<S: 'static> FromRequest<S> for HttpRequest<S> {
{
type Config = (); type Config = ();
type Result = FutureResult<Self, Error>; type Result = FutureResult<Self, Error>;
@ -540,10 +555,13 @@ impl<S> io::Read for HttpRequest<S> {
} }
} }
Ok(Async::Ready(None)) => Ok(0), Ok(Async::Ready(None)) => Ok(0),
Ok(Async::NotReady) => Ok(Async::NotReady) => {
Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready")), Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready"))
Err(e) => }
Err(io::Error::new(io::ErrorKind::Other, failure::Error::from(e).compat())), Err(e) => Err(io::Error::new(
io::ErrorKind::Other,
failure::Error::from(e).compat(),
)),
} }
} else { } else {
Ok(0) Ok(0)
@ -556,8 +574,12 @@ impl<S> AsyncRead for HttpRequest<S> {}
impl<S> fmt::Debug for HttpRequest<S> { impl<S> fmt::Debug for HttpRequest<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!( let res = writeln!(
f, "\nHttpRequest {:?} {}:{}", f,
self.as_ref().version, self.as_ref().method, self.path_decoded()); "\nHttpRequest {:?} {}:{}",
self.as_ref().version,
self.as_ref().method,
self.path_decoded()
);
if !self.query_string().is_empty() { if !self.query_string().is_empty() {
let _ = writeln!(f, " query: ?{:?}", self.query_string()); let _ = writeln!(f, " query: ?{:?}", self.query_string());
} }
@ -575,11 +597,11 @@ impl<S> fmt::Debug for HttpRequest<S> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use http::{Uri, HttpTryFrom}; use http::{HttpTryFrom, Uri};
use router::Resource;
use resource::ResourceHandler; use resource::ResourceHandler;
use test::TestRequest; use router::Resource;
use server::ServerSettings; use server::ServerSettings;
use test::TestRequest;
#[test] #[test]
fn test_debug() { fn test_debug() {
@ -652,12 +674,19 @@ mod tests {
#[test] #[test]
fn test_url_for() { fn test_url_for() {
let req2 = HttpRequest::default(); let req2 = HttpRequest::default();
assert_eq!(req2.url_for("unknown", &["test"]), assert_eq!(
Err(UrlGenerationError::RouterNotAvailable)); req2.url_for("unknown", &["test"]),
Err(UrlGenerationError::RouterNotAvailable)
);
let mut resource = ResourceHandler::<()>::default(); let mut resource = ResourceHandler::<()>::default();
resource.name("index"); resource.name("index");
let routes = vec!((Resource::new("index", "/user/{name}.{ext}"), Some(resource))); let routes = vec![
(
Resource::new("index", "/user/{name}.{ext}"),
Some(resource),
),
];
let (router, _) = Router::new("/", ServerSettings::default(), routes); let (router, _) = Router::new("/", ServerSettings::default(), routes);
assert!(router.has_route("/user/test.html")); assert!(router.has_route("/user/test.html"));
assert!(!router.has_route("/test/unknown")); assert!(!router.has_route("/test/unknown"));
@ -665,12 +694,19 @@ mod tests {
let req = TestRequest::with_header(header::HOST, "www.rust-lang.org") let req = TestRequest::with_header(header::HOST, "www.rust-lang.org")
.finish_with_router(router); .finish_with_router(router);
assert_eq!(req.url_for("unknown", &["test"]), assert_eq!(
Err(UrlGenerationError::ResourceNotFound)); req.url_for("unknown", &["test"]),
assert_eq!(req.url_for("index", &["test"]), Err(UrlGenerationError::ResourceNotFound)
Err(UrlGenerationError::NotEnoughElements)); );
assert_eq!(
req.url_for("index", &["test"]),
Err(UrlGenerationError::NotEnoughElements)
);
let url = req.url_for("index", &["test", "html"]); let url = req.url_for("index", &["test", "html"]);
assert_eq!(url.ok().unwrap().as_str(), "http://www.rust-lang.org/user/test.html"); assert_eq!(
url.ok().unwrap().as_str(),
"http://www.rust-lang.org/user/test.html"
);
} }
#[test] #[test]
@ -679,15 +715,22 @@ mod tests {
let mut resource = ResourceHandler::<()>::default(); let mut resource = ResourceHandler::<()>::default();
resource.name("index"); resource.name("index");
let routes = vec![(Resource::new("index", "/user/{name}.{ext}"), Some(resource))]; let routes = vec![
(
Resource::new("index", "/user/{name}.{ext}"),
Some(resource),
),
];
let (router, _) = Router::new("/prefix/", ServerSettings::default(), routes); let (router, _) = Router::new("/prefix/", ServerSettings::default(), routes);
assert!(router.has_route("/user/test.html")); assert!(router.has_route("/user/test.html"));
assert!(!router.has_route("/prefix/user/test.html")); assert!(!router.has_route("/prefix/user/test.html"));
let req = req.with_state(Rc::new(()), router); let req = req.with_state(Rc::new(()), router);
let url = req.url_for("index", &["test", "html"]); let url = req.url_for("index", &["test", "html"]);
assert_eq!(url.ok().unwrap().as_str(), assert_eq!(
"http://www.rust-lang.org/prefix/user/test.html"); url.ok().unwrap().as_str(),
"http://www.rust-lang.org/prefix/user/test.html"
);
} }
#[test] #[test]
@ -697,12 +740,19 @@ mod tests {
let mut resource = ResourceHandler::<()>::default(); let mut resource = ResourceHandler::<()>::default();
resource.name("index"); resource.name("index");
let routes = vec![ let routes = vec![
(Resource::external("youtube", "https://youtube.com/watch/{video_id}"), None)]; (
Resource::external("youtube", "https://youtube.com/watch/{video_id}"),
None,
),
];
let (router, _) = Router::new::<()>("", ServerSettings::default(), routes); let (router, _) = Router::new::<()>("", ServerSettings::default(), routes);
assert!(!router.has_route("https://youtube.com/watch/unknown")); assert!(!router.has_route("https://youtube.com/watch/unknown"));
let req = req.with_state(Rc::new(()), router); let req = req.with_state(Rc::new(()), router);
let url = req.url_for("youtube", &["oHg5SJYRHA0"]); let url = req.url_for("youtube", &["oHg5SJYRHA0"]);
assert_eq!(url.ok().unwrap().as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); assert_eq!(
url.ok().unwrap().as_str(),
"https://youtube.com/watch/oHg5SJYRHA0"
);
} }
} }

View File

@ -1,30 +1,29 @@
//! Http response //! Http response
use std::{mem, str, fmt};
use std::rc::Rc;
use std::io::Write;
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io::Write;
use std::rc::Rc;
use std::{fmt, mem, str};
use bytes::{BufMut, Bytes, BytesMut};
use cookie::{Cookie, CookieJar}; use cookie::{Cookie, CookieJar};
use bytes::{Bytes, BytesMut, BufMut};
use futures::Stream; use futures::Stream;
use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use serde_json; use http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode, Version};
use serde::Serialize; use serde::Serialize;
use serde_json;
use body::Body; use body::Body;
use client::ClientResponse;
use error::Error; use error::Error;
use handler::Responder; use handler::Responder;
use header::{Header, IntoHeaderValue, ContentEncoding}; use header::{ContentEncoding, Header, IntoHeaderValue};
use httprequest::HttpRequest;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use client::ClientResponse; use httprequest::HttpRequest;
/// max write buffer size 64k /// max write buffer size 64k
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// Represents various types of connection /// Represents various types of connection
#[derive(Copy, Clone, PartialEq, Debug)] #[derive(Copy, Clone, PartialEq, Debug)]
pub enum ConnectionType { pub enum ConnectionType {
@ -37,7 +36,10 @@ pub enum ConnectionType {
} }
/// An HTTP Response /// An HTTP Response
pub struct HttpResponse(Option<Box<InnerHttpResponse>>, Rc<UnsafeCell<HttpResponsePool>>); pub struct HttpResponse(
Option<Box<InnerHttpResponse>>,
Rc<UnsafeCell<HttpResponsePool>>,
);
impl Drop for HttpResponse { impl Drop for HttpResponse {
fn drop(&mut self) { fn drop(&mut self) {
@ -48,7 +50,6 @@ impl Drop for HttpResponse {
} }
impl HttpResponse { impl HttpResponse {
#[inline(always)] #[inline(always)]
#[cfg_attr(feature = "cargo-clippy", allow(inline_always))] #[cfg_attr(feature = "cargo-clippy", allow(inline_always))]
fn get_ref(&self) -> &InnerHttpResponse { fn get_ref(&self) -> &InnerHttpResponse {
@ -103,7 +104,7 @@ impl HttpResponse {
response, response,
pool, pool,
err: None, err: None,
cookies: None, // TODO: convert set-cookie headers cookies: None, // TODO: convert set-cookie headers
} }
} }
@ -149,7 +150,10 @@ impl HttpResponse {
if let Some(reason) = self.get_ref().reason { if let Some(reason) = self.get_ref().reason {
reason reason
} else { } else {
self.get_ref().status.canonical_reason().unwrap_or("<unknown status code>") self.get_ref()
.status
.canonical_reason()
.unwrap_or("<unknown status code>")
} }
} }
@ -241,9 +245,13 @@ impl HttpResponse {
impl fmt::Debug for HttpResponse { impl fmt::Debug for HttpResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!(f, "\nHttpResponse {:?} {}{}", let res = writeln!(
self.get_ref().version, self.get_ref().status, f,
self.get_ref().reason.unwrap_or("")); "\nHttpResponse {:?} {}{}",
self.get_ref().version,
self.get_ref().status,
self.get_ref().reason.unwrap_or("")
);
let _ = writeln!(f, " encoding: {:?}", self.get_ref().encoding); let _ = writeln!(f, " encoding: {:?}", self.get_ref().encoding);
let _ = writeln!(f, " headers:"); let _ = writeln!(f, " headers:");
for (key, val) in self.get_ref().headers.iter() { for (key, val) in self.get_ref().headers.iter() {
@ -299,11 +307,12 @@ impl HttpResponseBuilder {
/// fn main() {} /// fn main() {}
/// ``` /// ```
#[doc(hidden)] #[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
{
if let Some(parts) = parts(&mut self.response, &self.err) { if let Some(parts) = parts(&mut self.response, &self.err) {
match hdr.try_into() { match hdr.try_into() {
Ok(value) => { parts.headers.append(H::name(), value); } Ok(value) => {
parts.headers.append(H::name(), value);
}
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
} }
} }
@ -325,16 +334,17 @@ impl HttpResponseBuilder {
/// fn main() {} /// fn main() {}
/// ``` /// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where HeaderName: HttpTryFrom<K>, where
V: IntoHeaderValue, HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
if let Some(parts) = parts(&mut self.response, &self.err) { if let Some(parts) = parts(&mut self.response, &self.err) {
match HeaderName::try_from(key) { match HeaderName::try_from(key) {
Ok(key) => { Ok(key) => match value.try_into() {
match value.try_into() { Ok(value) => {
Ok(value) => { parts.headers.append(key, value); } parts.headers.append(key, value);
Err(e) => self.err = Some(e.into()),
} }
Err(e) => self.err = Some(e.into()),
}, },
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
}; };
@ -354,8 +364,9 @@ impl HttpResponseBuilder {
/// Set content encoding. /// Set content encoding.
/// ///
/// By default `ContentEncoding::Auto` is used, which automatically /// By default `ContentEncoding::Auto` is used, which automatically
/// negotiates content encoding based on request's `Accept-Encoding` headers. /// negotiates content encoding based on request's `Accept-Encoding`
/// To enforce specific encoding, use specific ContentEncoding` value. /// headers. To enforce specific encoding, use specific
/// ContentEncoding` value.
#[inline] #[inline]
pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self {
if let Some(parts) = parts(&mut self.response, &self.err) { if let Some(parts) = parts(&mut self.response, &self.err) {
@ -408,11 +419,14 @@ impl HttpResponseBuilder {
/// Set response content type /// Set response content type
#[inline] #[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self pub fn content_type<V>(&mut self, value: V) -> &mut Self
where HeaderValue: HttpTryFrom<V> where
HeaderValue: HttpTryFrom<V>,
{ {
if let Some(parts) = parts(&mut self.response, &self.err) { if let Some(parts) = parts(&mut self.response, &self.err) {
match HeaderValue::try_from(value) { match HeaderValue::try_from(value) {
Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); }, Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
Err(e) => self.err = Some(e.into()), Err(e) => self.err = Some(e.into()),
}; };
} }
@ -452,12 +466,16 @@ impl HttpResponseBuilder {
jar.add(cookie.into_owned()); jar.add(cookie.into_owned());
self.cookies = Some(jar) self.cookies = Some(jar)
} else { } else {
self.cookies.as_mut().unwrap().add(cookie.into_owned()); self.cookies
.as_mut()
.unwrap()
.add(cookie.into_owned());
} }
self self
} }
/// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()` method. /// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()`
/// method.
pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self {
{ {
if self.cookies.is_none() { if self.cookies.is_none() {
@ -471,9 +489,11 @@ impl HttpResponseBuilder {
self self
} }
/// This method calls provided closure with builder reference if value is true. /// This method calls provided closure with builder reference if value is
/// true.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where F: FnOnce(&mut HttpResponseBuilder) where
F: FnOnce(&mut HttpResponseBuilder),
{ {
if value { if value {
f(self); f(self);
@ -481,9 +501,11 @@ impl HttpResponseBuilder {
self self
} }
/// This method calls provided closure with builder reference if value is Some. /// This method calls provided closure with builder reference if value is
/// Some.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where F: FnOnce(T, &mut HttpResponseBuilder) where
F: FnOnce(T, &mut HttpResponseBuilder),
{ {
if let Some(val) = value { if let Some(val) = value {
f(val, self); f(val, self);
@ -494,8 +516,8 @@ impl HttpResponseBuilder {
/// Set write buffer capacity /// Set write buffer capacity
/// ///
/// This parameter makes sense only for streaming response /// This parameter makes sense only for streaming response
/// or actor. If write buffer reaches specified capacity, stream or actor get /// or actor. If write buffer reaches specified capacity, stream or actor
/// paused. /// get paused.
/// ///
/// Default write buffer capacity is 64kb /// Default write buffer capacity is 64kb
pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self { pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self {
@ -510,9 +532,11 @@ impl HttpResponseBuilder {
/// `HttpResponseBuilder` can not be used after this call. /// `HttpResponseBuilder` can not be used after this call.
pub fn body<B: Into<Body>>(&mut self, body: B) -> HttpResponse { pub fn body<B: Into<Body>>(&mut self, body: B) -> HttpResponse {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
return Error::from(e).into() return Error::from(e).into();
} }
let mut response = self.response.take().expect("cannot reuse response builder"); let mut response = self.response
.take()
.expect("cannot reuse response builder");
if let Some(ref jar) = self.cookies { if let Some(ref jar) = self.cookies {
for cookie in jar.delta() { for cookie in jar.delta() {
match HeaderValue::from_str(&cookie.to_string()) { match HeaderValue::from_str(&cookie.to_string()) {
@ -530,10 +554,13 @@ impl HttpResponseBuilder {
/// ///
/// `HttpResponseBuilder` can not be used after this call. /// `HttpResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> HttpResponse pub fn streaming<S, E>(&mut self, stream: S) -> HttpResponse
where S: Stream<Item=Bytes, Error=E> + 'static, where
E: Into<Error>, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error>,
{ {
self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into())))) self.body(Body::Streaming(Box::new(
stream.map_err(|e| e.into()),
)))
} }
/// Set a json body and generate `HttpResponse` /// Set a json body and generate `HttpResponse`
@ -542,19 +569,19 @@ impl HttpResponseBuilder {
pub fn json<T: Serialize>(&mut self, value: T) -> HttpResponse { pub fn json<T: Serialize>(&mut self, value: T) -> HttpResponse {
match serde_json::to_string(&value) { match serde_json::to_string(&value) {
Ok(body) => { Ok(body) => {
let contains = let contains = if let Some(parts) = parts(&mut self.response, &self.err)
if let Some(parts) = parts(&mut self.response, &self.err) { {
parts.headers.contains_key(header::CONTENT_TYPE) parts.headers.contains_key(header::CONTENT_TYPE)
} else { } else {
true true
}; };
if !contains { if !contains {
self.header(header::CONTENT_TYPE, "application/json"); self.header(header::CONTENT_TYPE, "application/json");
} }
self.body(body) self.body(body)
}, }
Err(e) => Error::from(e).into() Err(e) => Error::from(e).into(),
} }
} }
@ -579,11 +606,11 @@ impl HttpResponseBuilder {
#[inline] #[inline]
#[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))] #[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))]
fn parts<'a>(parts: &'a mut Option<Box<InnerHttpResponse>>, err: &Option<HttpError>) fn parts<'a>(
-> Option<&'a mut Box<InnerHttpResponse>> parts: &'a mut Option<Box<InnerHttpResponse>>, err: &Option<HttpError>
{ ) -> Option<&'a mut Box<InnerHttpResponse>> {
if err.is_some() { if err.is_some() {
return None return None;
} }
parts.as_mut() parts.as_mut()
} }
@ -628,8 +655,8 @@ impl Responder for &'static str {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8") .content_type("text/plain; charset=utf-8")
.body(self)) .body(self))
} }
} }
@ -647,8 +674,8 @@ impl Responder for &'static [u8] {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream") .content_type("application/octet-stream")
.body(self)) .body(self))
} }
} }
@ -666,8 +693,8 @@ impl Responder for String {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8") .content_type("text/plain; charset=utf-8")
.body(self)) .body(self))
} }
} }
@ -685,8 +712,8 @@ impl<'a> Responder for &'a String {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("text/plain; charset=utf-8") .content_type("text/plain; charset=utf-8")
.body(self)) .body(self))
} }
} }
@ -704,8 +731,8 @@ impl Responder for Bytes {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream") .content_type("application/octet-stream")
.body(self)) .body(self))
} }
} }
@ -723,8 +750,8 @@ impl Responder for BytesMut {
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> { fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
Ok(req.build_response(StatusCode::OK) Ok(req.build_response(StatusCode::OK)
.content_type("application/octet-stream") .content_type("application/octet-stream")
.body(self)) .body(self))
} }
} }
@ -745,7 +772,9 @@ impl<'a> From<&'a ClientResponse> for HttpResponseBuilder {
impl<'a, S> From<&'a HttpRequest<S>> for HttpResponseBuilder { impl<'a, S> From<&'a HttpRequest<S>> for HttpResponseBuilder {
fn from(req: &'a HttpRequest<S>) -> HttpResponseBuilder { fn from(req: &'a HttpRequest<S>) -> HttpResponseBuilder {
if let Some(router) = req.router() { if let Some(router) = req.router() {
router.server_settings().get_response_builder(StatusCode::OK) router
.server_settings()
.get_response_builder(StatusCode::OK)
} else { } else {
HttpResponse::Ok() HttpResponse::Ok()
} }
@ -768,7 +797,6 @@ struct InnerHttpResponse {
} }
impl InnerHttpResponse { impl InnerHttpResponse {
#[inline] #[inline]
fn new(status: StatusCode, body: Body) -> InnerHttpResponse { fn new(status: StatusCode, body: Body) -> InnerHttpResponse {
InnerHttpResponse { InnerHttpResponse {
@ -793,38 +821,41 @@ pub(crate) struct HttpResponsePool(VecDeque<Box<InnerHttpResponse>>);
thread_local!(static POOL: Rc<UnsafeCell<HttpResponsePool>> = HttpResponsePool::pool()); thread_local!(static POOL: Rc<UnsafeCell<HttpResponsePool>> = HttpResponsePool::pool());
impl HttpResponsePool { impl HttpResponsePool {
pub fn pool() -> Rc<UnsafeCell<HttpResponsePool>> { pub fn pool() -> Rc<UnsafeCell<HttpResponsePool>> {
Rc::new(UnsafeCell::new(HttpResponsePool(VecDeque::with_capacity(128)))) Rc::new(UnsafeCell::new(HttpResponsePool(
VecDeque::with_capacity(128),
)))
} }
#[inline] #[inline]
pub fn get_builder(pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode) pub fn get_builder(
-> HttpResponseBuilder pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode
{ ) -> HttpResponseBuilder {
let p = unsafe{&mut *pool.as_ref().get()}; let p = unsafe { &mut *pool.as_ref().get() };
if let Some(mut msg) = p.0.pop_front() { if let Some(mut msg) = p.0.pop_front() {
msg.status = status; msg.status = status;
HttpResponseBuilder { HttpResponseBuilder {
response: Some(msg), response: Some(msg),
pool: Some(Rc::clone(pool)), pool: Some(Rc::clone(pool)),
err: None, err: None,
cookies: None } cookies: None,
}
} else { } else {
let msg = Box::new(InnerHttpResponse::new(status, Body::Empty)); let msg = Box::new(InnerHttpResponse::new(status, Body::Empty));
HttpResponseBuilder { HttpResponseBuilder {
response: Some(msg), response: Some(msg),
pool: Some(Rc::clone(pool)), pool: Some(Rc::clone(pool)),
err: None, err: None,
cookies: None } cookies: None,
}
} }
} }
#[inline] #[inline]
pub fn get_response(pool: &Rc<UnsafeCell<HttpResponsePool>>, pub fn get_response(
status: StatusCode, body: Body) -> HttpResponse pool: &Rc<UnsafeCell<HttpResponsePool>>, status: StatusCode, body: Body
{ ) -> HttpResponse {
let p = unsafe{&mut *pool.as_ref().get()}; let p = unsafe { &mut *pool.as_ref().get() };
if let Some(mut msg) = p.0.pop_front() { if let Some(mut msg) = p.0.pop_front() {
msg.status = status; msg.status = status;
msg.body = body; msg.body = body;
@ -847,9 +878,10 @@ impl HttpResponsePool {
#[inline(always)] #[inline(always)]
#[cfg_attr(feature = "cargo-clippy", allow(boxed_local, inline_always))] #[cfg_attr(feature = "cargo-clippy", allow(boxed_local, inline_always))]
fn release(pool: &Rc<UnsafeCell<HttpResponsePool>>, mut inner: Box<InnerHttpResponse>) fn release(
{ pool: &Rc<UnsafeCell<HttpResponsePool>>, mut inner: Box<InnerHttpResponse>
let pool = unsafe{&mut *pool.as_ref().get()}; ) {
let pool = unsafe { &mut *pool.as_ref().get() };
if pool.0.len() < 128 { if pool.0.len() < 128 {
inner.headers.clear(); inner.headers.clear();
inner.version = None; inner.version = None;
@ -868,12 +900,12 @@ impl HttpResponsePool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::str::FromStr;
use time::Duration;
use http::{Method, Uri};
use http::header::{COOKIE, CONTENT_TYPE, HeaderValue};
use body::Binary; use body::Binary;
use http; use http;
use http::header::{HeaderValue, CONTENT_TYPE, COOKIE};
use http::{Method, Uri};
use std::str::FromStr;
use time::Duration;
#[test] #[test]
fn test_debug() { fn test_debug() {
@ -892,25 +924,37 @@ mod tests {
headers.insert(COOKIE, HeaderValue::from_static("cookie2=value2")); headers.insert(COOKIE, HeaderValue::from_static("cookie2=value2"));
let req = HttpRequest::new( let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let cookies = req.cookies().unwrap(); let cookies = req.cookies().unwrap();
let resp = HttpResponse::Ok() let resp = HttpResponse::Ok()
.cookie(http::Cookie::build("name", "value") .cookie(
http::Cookie::build("name", "value")
.domain("www.rust-lang.org") .domain("www.rust-lang.org")
.path("/test") .path("/test")
.http_only(true) .http_only(true)
.max_age(Duration::days(1)) .max_age(Duration::days(1))
.finish()) .finish(),
)
.del_cookie(&cookies[0]) .del_cookie(&cookies[0])
.finish(); .finish();
let mut val: Vec<_> = resp.headers().get_all("Set-Cookie") let mut val: Vec<_> = resp.headers()
.iter().map(|v| v.to_str().unwrap().to_owned()).collect(); .get_all("Set-Cookie")
.iter()
.map(|v| v.to_str().unwrap().to_owned())
.collect();
val.sort(); val.sort();
assert!(val[0].starts_with("cookie2=; Max-Age=0;")); assert!(val[0].starts_with("cookie2=; Max-Age=0;"));
assert_eq!( assert_eq!(
val[1],"name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400"); val[1],
"name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400"
);
} }
#[test] #[test]
@ -931,15 +975,21 @@ mod tests {
#[test] #[test]
fn test_force_close() { fn test_force_close() {
let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
assert!(!resp.keep_alive().unwrap()) assert!(!resp.keep_alive().unwrap())
} }
#[test] #[test]
fn test_content_type() { fn test_content_type() {
let resp = HttpResponse::build(StatusCode::OK) let resp = HttpResponse::build(StatusCode::OK)
.content_type("text/plain").body(Body::Empty); .content_type("text/plain")
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") .body(Body::Empty);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
"text/plain"
)
} }
#[test] #[test]
@ -947,25 +997,29 @@ mod tests {
let resp = HttpResponse::build(StatusCode::OK).finish(); let resp = HttpResponse::build(StatusCode::OK).finish();
assert_eq!(resp.content_encoding(), None); assert_eq!(resp.content_encoding(), None);
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
{ {
let resp = HttpResponse::build(StatusCode::OK) let resp = HttpResponse::build(StatusCode::OK)
.content_encoding(ContentEncoding::Br).finish(); .content_encoding(ContentEncoding::Br)
.finish();
assert_eq!(resp.content_encoding(), Some(ContentEncoding::Br)); assert_eq!(resp.content_encoding(), Some(ContentEncoding::Br));
} }
let resp = HttpResponse::build(StatusCode::OK) let resp = HttpResponse::build(StatusCode::OK)
.content_encoding(ContentEncoding::Gzip).finish(); .content_encoding(ContentEncoding::Gzip)
.finish();
assert_eq!(resp.content_encoding(), Some(ContentEncoding::Gzip)); assert_eq!(resp.content_encoding(), Some(ContentEncoding::Gzip));
} }
#[test] #[test]
fn test_json() { fn test_json() {
let resp = HttpResponse::build(StatusCode::OK) let resp = HttpResponse::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]);
.json(vec!["v1", "v2", "v3"]);
let ct = resp.headers().get(CONTENT_TYPE).unwrap(); let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, HeaderValue::from_static("application/json")); assert_eq!(ct, HeaderValue::from_static("application/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))); assert_eq!(
*resp.body(),
Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))
);
} }
#[test] #[test]
@ -975,7 +1029,10 @@ mod tests {
.json(vec!["v1", "v2", "v3"]); .json(vec!["v1", "v2", "v3"]);
let ct = resp.headers().get(CONTENT_TYPE).unwrap(); let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, HeaderValue::from_static("text/json")); assert_eq!(ct, HeaderValue::from_static("text/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))); assert_eq!(
*resp.body(),
Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))
);
} }
impl Body { impl Body {
@ -993,91 +1050,152 @@ mod tests {
let resp: HttpResponse = "test".into(); let resp: HttpResponse = "test".into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test")); assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = "test".respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = "test".respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test")); assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = b"test".as_ref().into(); let resp: HttpResponse = b"test".as_ref().into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(b"test".as_ref())
);
let resp: HttpResponse = b"test".as_ref().respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = b"test".as_ref().respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(b"test".as_ref())
);
let resp: HttpResponse = "test".to_owned().into(); let resp: HttpResponse = "test".to_owned().into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from("test".to_owned())
);
let resp: HttpResponse = "test".to_owned().respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = "test"
.to_owned()
.respond_to(req.clone())
.ok()
.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from("test".to_owned())
);
let resp: HttpResponse = (&"test".to_owned()).into(); let resp: HttpResponse = (&"test".to_owned()).into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(&"test".to_owned())
);
let resp: HttpResponse = (&"test".to_owned()).respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = (&"test".to_owned())
.respond_to(req.clone())
.ok()
.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("text/plain; charset=utf-8")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned())); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(&"test".to_owned())
);
let b = Bytes::from_static(b"test"); let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.into(); let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test"))); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(Bytes::from_static(b"test"))
);
let b = Bytes::from_static(b"test"); let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test"))); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(Bytes::from_static(b"test"))
);
let b = BytesMut::from("test"); let b = BytesMut::from("test");
let resp: HttpResponse = b.into(); let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test"))); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(BytesMut::from("test"))
);
let b = BytesMut::from("test"); let b = BytesMut::from("test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap(); let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), assert_eq!(
HeaderValue::from_static("application/octet-stream")); resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test"))); assert_eq!(
resp.body().binary().unwrap(),
&Binary::from(BytesMut::from("test"))
);
} }
#[test] #[test]

View File

@ -1,13 +1,12 @@
use std::str::FromStr;
use http::header::{self, HeaderName}; use http::header::{self, HeaderName};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use std::str::FromStr;
const X_FORWARDED_FOR: &str = "X-FORWARDED-FOR"; const X_FORWARDED_FOR: &str = "X-FORWARDED-FOR";
const X_FORWARDED_HOST: &str = "X-FORWARDED-HOST"; const X_FORWARDED_HOST: &str = "X-FORWARDED-HOST";
const X_FORWARDED_PROTO: &str = "X-FORWARDED-PROTO"; const X_FORWARDED_PROTO: &str = "X-FORWARDED-PROTO";
/// `HttpRequest` connection information /// `HttpRequest` connection information
pub struct ConnectionInfo<'a> { pub struct ConnectionInfo<'a> {
scheme: &'a str, scheme: &'a str,
@ -17,7 +16,6 @@ pub struct ConnectionInfo<'a> {
} }
impl<'a> ConnectionInfo<'a> { impl<'a> ConnectionInfo<'a> {
/// Create *ConnectionInfo* instance for a request. /// Create *ConnectionInfo* instance for a request.
#[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))]
pub fn new<S>(req: &'a HttpRequest<S>) -> ConnectionInfo<'a> { pub fn new<S>(req: &'a HttpRequest<S>) -> ConnectionInfo<'a> {
@ -55,8 +53,9 @@ impl<'a> ConnectionInfo<'a> {
// scheme // scheme
if scheme.is_none() { if scheme.is_none() {
if let Some(h) = req.headers().get( if let Some(h) = req.headers()
HeaderName::from_str(X_FORWARDED_PROTO).unwrap()) { .get(HeaderName::from_str(X_FORWARDED_PROTO).unwrap())
{
if let Ok(h) = h.to_str() { if let Ok(h) = h.to_str() {
scheme = h.split(',').next().map(|v| v.trim()); scheme = h.split(',').next().map(|v| v.trim());
} }
@ -75,7 +74,9 @@ impl<'a> ConnectionInfo<'a> {
// host // host
if host.is_none() { if host.is_none() {
if let Some(h) = req.headers().get(HeaderName::from_str(X_FORWARDED_HOST).unwrap()) { if let Some(h) = req.headers()
.get(HeaderName::from_str(X_FORWARDED_HOST).unwrap())
{
if let Ok(h) = h.to_str() { if let Ok(h) = h.to_str() {
host = h.split(',').next().map(|v| v.trim()); host = h.split(',').next().map(|v| v.trim());
} }
@ -97,13 +98,15 @@ impl<'a> ConnectionInfo<'a> {
// remote addr // remote addr
if remote.is_none() { if remote.is_none() {
if let Some(h) = req.headers().get( if let Some(h) = req.headers()
HeaderName::from_str(X_FORWARDED_FOR).unwrap()) { .get(HeaderName::from_str(X_FORWARDED_FOR).unwrap())
{
if let Ok(h) = h.to_str() { if let Ok(h) = h.to_str() {
remote = h.split(',').next().map(|v| v.trim()); remote = h.split(',').next().map(|v| v.trim());
} }
} }
if remote.is_none() { // get peeraddr from socketaddr if remote.is_none() {
// get peeraddr from socketaddr
peer = req.peer_addr().map(|addr| format!("{}", addr)); peer = req.peer_addr().map(|addr| format!("{}", addr));
} }
} }
@ -176,7 +179,9 @@ mod tests {
req.headers_mut().insert( req.headers_mut().insert(
header::FORWARDED, header::FORWARDED,
HeaderValue::from_static( HeaderValue::from_static(
"for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org")); "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org",
),
);
let info = ConnectionInfo::new(&req); let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "https"); assert_eq!(info.scheme(), "https");
@ -185,7 +190,9 @@ mod tests {
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert( req.headers_mut().insert(
header::HOST, HeaderValue::from_static("rust-lang.org")); header::HOST,
HeaderValue::from_static("rust-lang.org"),
);
let info = ConnectionInfo::new(&req); let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "http"); assert_eq!(info.scheme(), "http");
@ -194,20 +201,26 @@ mod tests {
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert( req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_FOR).unwrap(), HeaderValue::from_static("192.0.2.60")); HeaderName::from_str(X_FORWARDED_FOR).unwrap(),
HeaderValue::from_static("192.0.2.60"),
);
let info = ConnectionInfo::new(&req); let info = ConnectionInfo::new(&req);
assert_eq!(info.remote(), Some("192.0.2.60")); assert_eq!(info.remote(), Some("192.0.2.60"));
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert( req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_HOST).unwrap(), HeaderValue::from_static("192.0.2.60")); HeaderName::from_str(X_FORWARDED_HOST).unwrap(),
HeaderValue::from_static("192.0.2.60"),
);
let info = ConnectionInfo::new(&req); let info = ConnectionInfo::new(&req);
assert_eq!(info.host(), "192.0.2.60"); assert_eq!(info.host(), "192.0.2.60");
assert_eq!(info.remote(), None); assert_eq!(info.remote(), None);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert( req.headers_mut().insert(
HeaderName::from_str(X_FORWARDED_PROTO).unwrap(), HeaderValue::from_static("https")); HeaderName::from_str(X_FORWARDED_PROTO).unwrap(),
HeaderValue::from_static("https"),
);
let info = ConnectionInfo::new(&req); let info = ConnectionInfo::new(&req);
assert_eq!(info.scheme(), "https"); assert_eq!(info.scheme(), "https");
} }

View File

@ -1,16 +1,16 @@
use bytes::{Bytes, BytesMut};
use futures::{Future, Poll, Stream};
use http::header::CONTENT_LENGTH;
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use bytes::{Bytes, BytesMut};
use futures::{Poll, Future, Stream};
use http::header::CONTENT_LENGTH;
use mime; use mime;
use serde_json;
use serde::Serialize; use serde::Serialize;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde_json;
use error::{Error, JsonPayloadError, PayloadError}; use error::{Error, JsonPayloadError, PayloadError};
use handler::{Responder, FromRequest}; use handler::{FromRequest, Responder};
use http::StatusCode; use http::StatusCode;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
@ -18,80 +18,15 @@ use httpresponse::HttpResponse;
/// Json helper /// Json helper
/// ///
/// Json can be used for two different purpose. First is for json response generation /// Json can be used for two different purpose. First is for json response
/// and second is for extracting typed information from request's payload. /// generation and second is for extracting typed information from request's
pub struct Json<T>(pub T); /// payload.
impl<T> Json<T> {
/// Deconstruct to an inner value
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> fmt::Debug for Json<T> where T: fmt::Debug {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Json: {:?}", self.0)
}
}
impl<T> fmt::Display for Json<T> where T: fmt::Display {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
/// The `Json` type allows you to respond with well-formed JSON data: simply
/// return a value of type Json<T> where T is the type of a structure
/// to serialize into *JSON*. The type `T` must implement the `Serialize`
/// trait from *serde*.
/// ///
/// ```rust /// To extract typed information from request's body, the type `T` must
/// # extern crate actix_web; /// implement the `Deserialize` trait from *serde*.
/// # #[macro_use] extern crate serde_derive;
/// # use actix_web::*;
/// #
/// #[derive(Serialize)]
/// struct MyObj {
/// name: String,
/// }
/// ///
/// fn index(req: HttpRequest) -> Result<Json<MyObj>> { /// [**JsonConfig**](dev/struct.JsonConfig.html) allows to configure extraction
/// Ok(Json(MyObj{name: req.match_info().query("name")?})) /// process.
/// }
/// # fn main() {}
/// ```
impl<T: Serialize> Responder for Json<T> {
type Item = HttpResponse;
type Error = Error;
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
let body = serde_json::to_string(&self.0)?;
Ok(req.build_response(StatusCode::OK)
.content_type("application/json")
.body(body))
}
}
/// To extract typed information from request's body, the type `T` must implement the
/// `Deserialize` trait from *serde*.
///
/// [**JsonConfig**](dev/struct.JsonConfig.html) allows to configure extraction process.
/// ///
/// ## Example /// ## Example
/// ///
@ -116,11 +51,88 @@ impl<T: Serialize> Responder for Json<T> {
/// |r| r.method(http::Method::POST).with(index)); // <- use `with` extractor /// |r| r.method(http::Method::POST).with(index)); // <- use `with` extractor
/// } /// }
/// ``` /// ```
///
/// The `Json` type allows you to respond with well-formed JSON data: simply
/// return a value of type Json<T> where T is the type of a structure
/// to serialize into *JSON*. The type `T` must implement the `Serialize`
/// trait from *serde*.
///
/// ```rust
/// # extern crate actix_web;
/// # #[macro_use] extern crate serde_derive;
/// # use actix_web::*;
/// #
/// #[derive(Serialize)]
/// struct MyObj {
/// name: String,
/// }
///
/// fn index(req: HttpRequest) -> Result<Json<MyObj>> {
/// Ok(Json(MyObj{name: req.match_info().query("name")?}))
/// }
/// # fn main() {}
/// ```
pub struct Json<T>(pub T);
impl<T> Json<T> {
/// Deconstruct to an inner value
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for Json<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for Json<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> fmt::Debug for Json<T>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Json: {:?}", self.0)
}
}
impl<T> fmt::Display for Json<T>
where
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl<T: Serialize> Responder for Json<T> {
type Item = HttpResponse;
type Error = Error;
fn respond_to(self, req: HttpRequest) -> Result<HttpResponse, Error> {
let body = serde_json::to_string(&self.0)?;
Ok(req.build_response(StatusCode::OK)
.content_type("application/json")
.body(body))
}
}
impl<T, S> FromRequest<S> for Json<T> impl<T, S> FromRequest<S> for Json<T>
where T: DeserializeOwned + 'static, S: 'static where
T: DeserializeOwned + 'static,
S: 'static,
{ {
type Config = JsonConfig; type Config = JsonConfig;
type Result = Box<Future<Item=Self, Error=Error>>; type Result = Box<Future<Item = Self, Error = Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
@ -128,7 +140,8 @@ impl<T, S> FromRequest<S> for Json<T>
JsonBody::new(req.clone()) JsonBody::new(req.clone())
.limit(cfg.limit) .limit(cfg.limit)
.from_err() .from_err()
.map(Json)) .map(Json),
)
} }
} }
@ -163,7 +176,6 @@ pub struct JsonConfig {
} }
impl JsonConfig { impl JsonConfig {
/// Change max size of payload. By default max size is 256Kb /// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self { pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit; self.limit = limit;
@ -173,7 +185,7 @@ impl JsonConfig {
impl Default for JsonConfig { impl Default for JsonConfig {
fn default() -> Self { fn default() -> Self {
JsonConfig{limit: 262_144} JsonConfig { limit: 262_144 }
} }
} }
@ -208,17 +220,16 @@ impl Default for JsonConfig {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
pub struct JsonBody<T, U: DeserializeOwned>{ pub struct JsonBody<T, U: DeserializeOwned> {
limit: usize, limit: usize,
req: Option<T>, req: Option<T>,
fut: Option<Box<Future<Item=U, Error=JsonPayloadError>>>, fut: Option<Box<Future<Item = U, Error = JsonPayloadError>>>,
} }
impl<T, U: DeserializeOwned> JsonBody<T, U> { impl<T, U: DeserializeOwned> JsonBody<T, U> {
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
pub fn new(req: T) -> Self { pub fn new(req: T) -> Self {
JsonBody{ JsonBody {
limit: 262_144, limit: 262_144,
req: Some(req), req: Some(req),
fut: None, fut: None,
@ -233,7 +244,8 @@ impl<T, U: DeserializeOwned> JsonBody<T, U> {
} }
impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U> impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
where T: HttpMessage + Stream<Item=Bytes, Error=PayloadError> + 'static where
T: HttpMessage + Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
type Item = U; type Item = U;
type Error = JsonPayloadError; type Error = JsonPayloadError;
@ -259,7 +271,7 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
false false
}; };
if !json { if !json {
return Err(JsonPayloadError::ContentType) return Err(JsonPayloadError::ContentType);
} }
let limit = self.limit; let limit = self.limit;
@ -276,7 +288,10 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
self.fut = Some(Box::new(fut)); self.fut = Some(Box::new(fut));
} }
self.fut.as_mut().expect("JsonBody could not be used second time").poll() self.fut
.as_mut()
.expect("JsonBody could not be used second time")
.poll()
} }
} }
@ -284,11 +299,11 @@ impl<T, U: DeserializeOwned + 'static> Future for JsonBody<T, U>
mod tests { mod tests {
use super::*; use super::*;
use bytes::Bytes; use bytes::Bytes;
use http::header;
use futures::Async; use futures::Async;
use http::header;
use with::{With, ExtractorConfig};
use handler::Handler; use handler::Handler;
use with::{ExtractorConfig, With};
impl PartialEq for JsonPayloadError { impl PartialEq for JsonPayloadError {
fn eq(&self, other: &JsonPayloadError) -> bool { fn eq(&self, other: &JsonPayloadError) -> bool {
@ -313,59 +328,100 @@ mod tests {
#[test] #[test]
fn test_json() { fn test_json() {
let json = Json(MyObject{name: "test".to_owned()}); let json = Json(MyObject {
name: "test".to_owned(),
});
let resp = json.respond_to(HttpRequest::default()).unwrap(); let resp = json.respond_to(HttpRequest::default()).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "application/json"); assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
"application/json"
);
} }
#[test] #[test]
fn test_json_body() { fn test_json_body() {
let req = HttpRequest::default(); let req = HttpRequest::default();
let mut json = req.json::<MyObject>(); let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); assert_eq!(
json.poll().err().unwrap(),
JsonPayloadError::ContentType
);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE, req.headers_mut().insert(
header::HeaderValue::from_static("application/text")); header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"),
);
let mut json = req.json::<MyObject>(); let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); assert_eq!(
json.poll().err().unwrap(),
JsonPayloadError::ContentType
);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE, req.headers_mut().insert(
header::HeaderValue::from_static("application/json")); header::CONTENT_TYPE,
req.headers_mut().insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("application/json"),
header::HeaderValue::from_static("10000")); );
req.headers_mut().insert(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"),
);
let mut json = req.json::<MyObject>().limit(100); let mut json = req.json::<MyObject>().limit(100);
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE, req.headers_mut().insert(
header::HeaderValue::from_static("application/json")); header::CONTENT_TYPE,
req.headers_mut().insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("application/json"),
header::HeaderValue::from_static("16")); );
req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}")); req.headers_mut().insert(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
);
req.payload_mut()
.unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
let mut json = req.json::<MyObject>(); let mut json = req.json::<MyObject>();
assert_eq!(json.poll().ok().unwrap(), assert_eq!(
Async::Ready(MyObject{name: "test".to_owned()})); json.poll().ok().unwrap(),
Async::Ready(MyObject {
name: "test".to_owned()
})
);
} }
#[test] #[test]
fn test_with_json() { fn test_with_json() {
let mut cfg = ExtractorConfig::<_, Json<MyObject>>::default(); let mut cfg = ExtractorConfig::<_, Json<MyObject>>::default();
cfg.limit(4096); cfg.limit(4096);
let mut handler = With::new(|data: Json<MyObject>| {data}, cfg); let mut handler = With::new(|data: Json<MyObject>| data, cfg);
let req = HttpRequest::default(); let req = HttpRequest::default();
let err = handler.handle(req).as_response().unwrap().error().is_some(); let err = handler
.handle(req)
.as_response()
.unwrap()
.error()
.is_some();
assert!(err); assert!(err);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
req.headers_mut().insert(header::CONTENT_TYPE, req.headers_mut().insert(
header::HeaderValue::from_static("application/json")); header::CONTENT_TYPE,
req.headers_mut().insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("application/json"),
header::HeaderValue::from_static("16")); );
req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}")); req.headers_mut().insert(
let ok = handler.handle(req).as_response().unwrap().error().is_none(); header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
);
req.payload_mut()
.unread_data(Bytes::from_static(b"{\"name\": \"test\"}"));
let ok = handler
.handle(req)
.as_response()
.unwrap()
.error()
.is_none();
assert!(ok) assert!(ok)
} }
} }

View File

@ -64,17 +64,17 @@
#![cfg_attr(actix_nightly, feature( #![cfg_attr(actix_nightly, feature(
specialization, // for impl ErrorResponse for std::error::Error specialization, // for impl ErrorResponse for std::error::Error
))] ))]
#![cfg_attr(feature = "cargo-clippy", allow( #![cfg_attr(feature = "cargo-clippy",
decimal_literal_representation,suspicious_arithmetic_impl,))] allow(decimal_literal_representation, suspicious_arithmetic_impl))]
#[macro_use] #[macro_use]
extern crate log; extern crate log;
extern crate time;
extern crate base64; extern crate base64;
extern crate bytes;
extern crate byteorder; extern crate byteorder;
extern crate sha1; extern crate bytes;
extern crate regex; extern crate regex;
extern crate sha1;
extern crate time;
#[macro_use] #[macro_use]
extern crate bitflags; extern crate bitflags;
#[macro_use] #[macro_use]
@ -83,46 +83,49 @@ extern crate failure;
extern crate lazy_static; extern crate lazy_static;
#[macro_use] #[macro_use]
extern crate futures; extern crate futures;
extern crate futures_cpupool;
extern crate tokio_io;
extern crate tokio_core;
extern crate mio;
extern crate net2;
extern crate cookie; extern crate cookie;
extern crate futures_cpupool;
extern crate http as modhttp; extern crate http as modhttp;
extern crate httparse;
extern crate http_range; extern crate http_range;
extern crate httparse;
extern crate language_tags;
extern crate libc;
extern crate mime; extern crate mime;
extern crate mime_guess; extern crate mime_guess;
extern crate language_tags; extern crate mio;
extern crate net2;
extern crate rand; extern crate rand;
extern crate tokio_core;
extern crate tokio_io;
extern crate url; extern crate url;
extern crate libc; #[macro_use]
#[macro_use] extern crate serde; extern crate serde;
extern crate serde_json; #[cfg(feature = "brotli")]
extern crate serde_urlencoded;
extern crate flate2;
#[cfg(feature="brotli")]
extern crate brotli2; extern crate brotli2;
extern crate encoding; extern crate encoding;
extern crate percent_encoding; extern crate flate2;
extern crate smallvec;
extern crate num_cpus;
extern crate h2 as http2; extern crate h2 as http2;
extern crate num_cpus;
extern crate percent_encoding;
extern crate serde_json;
extern crate serde_urlencoded;
extern crate smallvec;
extern crate trust_dns_resolver; extern crate trust_dns_resolver;
#[macro_use] extern crate actix; #[macro_use]
extern crate actix;
#[cfg(test)] #[cfg(test)]
#[macro_use] extern crate serde_derive; #[macro_use]
extern crate serde_derive;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
extern crate native_tls; extern crate native_tls;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
extern crate tokio_tls; extern crate tokio_tls;
#[cfg(feature="openssl")] #[cfg(feature = "openssl")]
extern crate openssl; extern crate openssl;
#[cfg(feature="openssl")] #[cfg(feature = "openssl")]
extern crate tokio_openssl; extern crate tokio_openssl;
mod application; mod application;
@ -138,33 +141,33 @@ mod httprequest;
mod httpresponse; mod httpresponse;
mod info; mod info;
mod json; mod json;
mod route;
mod router;
mod resource;
mod param; mod param;
mod payload; mod payload;
mod pipeline; mod pipeline;
mod resource;
mod route;
mod router;
mod with; mod with;
pub mod client; pub mod client;
pub mod fs;
pub mod ws;
pub mod error; pub mod error;
pub mod multipart; pub mod fs;
pub mod middleware; pub mod middleware;
pub mod multipart;
pub mod pred; pub mod pred;
pub mod test;
pub mod server; pub mod server;
pub use extractor::{Path, Form, Query}; pub mod test;
pub use error::{Error, Result, ResponseError}; pub mod ws;
pub use body::{Body, Binary};
pub use json::Json;
pub use application::App; pub use application::App;
pub use body::{Binary, Body};
pub use context::HttpContext;
pub use error::{Error, ResponseError, Result};
pub use extractor::{Form, Path, Query};
pub use handler::{AsyncResponder, Either, FromRequest, FutureResponse, Responder, State};
pub use httpmessage::HttpMessage; pub use httpmessage::HttpMessage;
pub use httprequest::HttpRequest; pub use httprequest::HttpRequest;
pub use httpresponse::HttpResponse; pub use httpresponse::HttpResponse;
pub use handler::{Either, Responder, AsyncResponder, FromRequest, FutureResponse, State}; pub use json::Json;
pub use context::HttpContext;
#[doc(hidden)] #[doc(hidden)]
pub mod httpcodes; pub mod httpcodes;
@ -173,39 +176,39 @@ pub mod httpcodes;
#[allow(deprecated)] #[allow(deprecated)]
pub use application::Application; pub use application::Application;
#[cfg(feature="openssl")] #[cfg(feature = "openssl")]
pub(crate) const HAS_OPENSSL: bool = true; pub(crate) const HAS_OPENSSL: bool = true;
#[cfg(not(feature="openssl"))] #[cfg(not(feature = "openssl"))]
pub(crate) const HAS_OPENSSL: bool = false; pub(crate) const HAS_OPENSSL: bool = false;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
pub(crate) const HAS_TLS: bool = true; pub(crate) const HAS_TLS: bool = true;
#[cfg(not(feature="tls"))] #[cfg(not(feature = "tls"))]
pub(crate) const HAS_TLS: bool = false; pub(crate) const HAS_TLS: bool = false;
pub mod dev { pub mod dev {
//! The `actix-web` prelude for library developers //! The `actix-web` prelude for library developers
//! //!
//! The purpose of this module is to alleviate imports of many common actix traits //! The purpose of this module is to alleviate imports of many common actix
//! by adding a glob import to the top of actix heavy modules: //! traits by adding a glob import to the top of actix heavy modules:
//! //!
//! ``` //! ```
//! # #![allow(unused_imports)] //! # #![allow(unused_imports)]
//! use actix_web::dev::*; //! use actix_web::dev::*;
//! ``` //! ```
pub use body::BodyStream; pub use body::BodyStream;
pub use context::Drain; pub use context::Drain;
pub use json::{JsonBody, JsonConfig};
pub use info::ConnectionInfo;
pub use handler::{Handler, Reply};
pub use extractor::{FormConfig, PayloadConfig}; pub use extractor::{FormConfig, PayloadConfig};
pub use route::Route; pub use handler::{Handler, Reply};
pub use router::{Router, Resource, ResourceType}; pub use httpmessage::{MessageBody, UrlEncoded};
pub use resource::ResourceHandler;
pub use param::{FromParam, Params};
pub use httpmessage::{UrlEncoded, MessageBody};
pub use httpresponse::HttpResponseBuilder; pub use httpresponse::HttpResponseBuilder;
pub use info::ConnectionInfo;
pub use json::{JsonBody, JsonConfig};
pub use param::{FromParam, Params};
pub use resource::ResourceHandler;
pub use route::Route;
pub use router::{Resource, ResourceType, Router};
} }
pub mod http { pub mod http {
@ -215,15 +218,15 @@ pub mod http {
pub use modhttp::{Method, StatusCode, Version}; pub use modhttp::{Method, StatusCode, Version};
#[doc(hidden)] #[doc(hidden)]
pub use modhttp::{uri, Uri, Error, Extensions, HeaderMap, HttpTryFrom}; pub use modhttp::{uri, Error, Extensions, HeaderMap, HttpTryFrom, Uri};
pub use http_range::HttpRange;
pub use cookie::{Cookie, CookieBuilder}; pub use cookie::{Cookie, CookieBuilder};
pub use http_range::HttpRange;
pub use helpers::NormalizePath; pub use helpers::NormalizePath;
pub mod header { pub mod header {
pub use ::header::*; pub use header::*;
} }
pub use header::ContentEncoding; pub use header::ContentEncoding;
pub use httpresponse::ConnectionType; pub use httpresponse::ConnectionType;

View File

@ -7,7 +7,8 @@
//! //!
//! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. //! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
//! 2. Use any of the builder methods to set fields in the backend. //! 2. Use any of the builder methods to set fields in the backend.
//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend. //! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
//! constructed backend.
//! //!
//! Cors middleware could be used as parameter for `App::middleware()` or //! Cors middleware could be used as parameter for `App::middleware()` or
//! `ResourceHandler::middleware()` methods. But you have to use //! `ResourceHandler::middleware()` methods. But you have to use
@ -40,65 +41,69 @@
//! .register()); //! .register());
//! } //! }
//! ``` //! ```
//! In this example custom *CORS* middleware get registered for "/index.html" endpoint. //! In this example custom *CORS* middleware get registered for "/index.html"
//! endpoint.
//! //!
//! Cors middleware automatically handle *OPTIONS* preflight request. //! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
use http::{self, Method, HttpTryFrom, Uri, StatusCode};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use http::{self, HttpTryFrom, Method, StatusCode, Uri};
use application::App; use application::App;
use error::{Result, ResponseError}; use error::{ResponseError, Result};
use resource::ResourceHandler;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Response, Started}; use middleware::{Middleware, Response, Started};
use resource::ResourceHandler;
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
pub enum CorsError { pub enum CorsError {
/// The HTTP request header `Origin` is required but was not provided /// The HTTP request header `Origin` is required but was not provided
#[fail(display="The HTTP request header `Origin` is required but was not provided")] #[fail(display = "The HTTP request header `Origin` is required but was not provided")]
MissingOrigin, MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly. /// The HTTP request header `Origin` could not be parsed correctly.
#[fail(display="The HTTP request header `Origin` could not be parsed correctly.")] #[fail(display = "The HTTP request header `Origin` could not be parsed correctly.")]
BadOrigin, BadOrigin,
/// The request header `Access-Control-Request-Method` is required but is missing /// The request header `Access-Control-Request-Method` is required but is
#[fail(display="The request header `Access-Control-Request-Method` is required but is missing")] /// missing
#[fail(display = "The request header `Access-Control-Request-Method` is required but is missing")]
MissingRequestMethod, MissingRequestMethod,
/// The request header `Access-Control-Request-Method` has an invalid value /// The request header `Access-Control-Request-Method` has an invalid value
#[fail(display="The request header `Access-Control-Request-Method` has an invalid value")] #[fail(display = "The request header `Access-Control-Request-Method` has an invalid value")]
BadRequestMethod, BadRequestMethod,
/// The request header `Access-Control-Request-Headers` has an invalid value /// The request header `Access-Control-Request-Headers` has an invalid
#[fail(display="The request header `Access-Control-Request-Headers` has an invalid value")] /// value
#[fail(display = "The request header `Access-Control-Request-Headers` has an invalid value")]
BadRequestHeaders, BadRequestHeaders,
/// The request header `Access-Control-Request-Headers` is required but is missing. /// The request header `Access-Control-Request-Headers` is required but is
#[fail(display="The request header `Access-Control-Request-Headers` is required but is /// missing.
#[fail(display = "The request header `Access-Control-Request-Headers` is required but is
missing")] missing")]
MissingRequestHeaders, MissingRequestHeaders,
/// Origin is not allowed to make this request /// Origin is not allowed to make this request
#[fail(display="Origin is not allowed to make this request")] #[fail(display = "Origin is not allowed to make this request")]
OriginNotAllowed, OriginNotAllowed,
/// Requested method is not allowed /// Requested method is not allowed
#[fail(display="Requested method is not allowed")] #[fail(display = "Requested method is not allowed")]
MethodNotAllowed, MethodNotAllowed,
/// One or more headers requested are not allowed /// One or more headers requested are not allowed
#[fail(display="One or more headers requested are not allowed")] #[fail(display = "One or more headers requested are not allowed")]
HeadersNotAllowed, HeadersNotAllowed,
} }
impl ResponseError for CorsError { impl ResponseError for CorsError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self)) HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self))
} }
} }
/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). /// An enum signifying that some of type T is allowed, or `All` (everything is
/// allowed).
/// ///
/// `Default` is implemented for this enum and is `All`. /// `Default` is implemented for this enum and is `All`.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
@ -166,9 +171,16 @@ impl Default for Cors {
origins: AllOrSome::default(), origins: AllOrSome::default(),
origins_str: None, origins_str: None,
methods: HashSet::from_iter( methods: HashSet::from_iter(
vec![Method::GET, Method::HEAD, vec![
Method::POST, Method::OPTIONS, Method::PUT, Method::GET,
Method::PATCH, Method::DELETE].into_iter()), Method::HEAD,
Method::POST,
Method::OPTIONS,
Method::PUT,
Method::PATCH,
Method::DELETE,
].into_iter(),
),
headers: AllOrSome::All, headers: AllOrSome::All,
expose_hdrs: None, expose_hdrs: None,
max_age: None, max_age: None,
@ -177,7 +189,9 @@ impl Default for Cors {
supports_credentials: false, supports_credentials: false,
vary_header: true, vary_header: true,
}; };
Cors{inner: Rc::new(inner)} Cors {
inner: Rc::new(inner),
}
} }
} }
@ -247,11 +261,13 @@ impl Cors {
/// This method register cors middleware with resource and /// This method register cors middleware with resource and
/// adds route for *OPTIONS* preflight requests. /// adds route for *OPTIONS* preflight requests.
/// ///
/// It is possible to register *Cors* middleware with `ResourceHandler::middleware()` /// It is possible to register *Cors* middleware with
/// method, but in that case *Cors* middleware wont be able to handle *OPTIONS* /// `ResourceHandler::middleware()` method, but in that case *Cors*
/// requests. /// middleware wont be able to handle *OPTIONS* requests.
pub fn register<S: 'static>(self, resource: &mut ResourceHandler<S>) { pub fn register<S: 'static>(self, resource: &mut ResourceHandler<S>) {
resource.method(Method::OPTIONS).h(|_| HttpResponse::Ok()); resource
.method(Method::OPTIONS)
.h(|_| HttpResponse::Ok());
resource.middleware(self); resource.middleware(self);
} }
@ -260,28 +276,32 @@ impl Cors {
if let Ok(origin) = hdr.to_str() { if let Ok(origin) = hdr.to_str() {
return match self.inner.origins { return match self.inner.origins {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_origins) => { AllOrSome::Some(ref allowed_origins) => allowed_origins
allowed_origins .get(origin)
.get(origin) .and_then(|_| Some(()))
.and_then(|_| Some(())) .ok_or_else(|| CorsError::OriginNotAllowed),
.ok_or_else(|| CorsError::OriginNotAllowed)
}
}; };
} }
Err(CorsError::BadOrigin) Err(CorsError::BadOrigin)
} else { } else {
return match self.inner.origins { return match self.inner.origins {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
_ => Err(CorsError::MissingOrigin) _ => Err(CorsError::MissingOrigin),
} };
} }
} }
fn validate_allowed_method<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_allowed_method<S>(
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { &self, req: &mut HttpRequest<S>
) -> Result<(), CorsError> {
if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
{
if let Ok(meth) = hdr.to_str() { if let Ok(meth) = hdr.to_str() {
if let Ok(method) = Method::try_from(meth) { if let Ok(method) = Method::try_from(meth) {
return self.inner.methods.get(&method) return self.inner
.methods
.get(&method)
.and_then(|_| Some(())) .and_then(|_| Some(()))
.ok_or_else(|| CorsError::MethodNotAllowed); .ok_or_else(|| CorsError::MethodNotAllowed);
} }
@ -292,24 +312,28 @@ impl Cors {
} }
} }
fn validate_allowed_headers<S>(&self, req: &mut HttpRequest<S>) -> Result<(), CorsError> { fn validate_allowed_headers<S>(
&self, req: &mut HttpRequest<S>
) -> Result<(), CorsError> {
match self.inner.headers { match self.inner.headers {
AllOrSome::All => Ok(()), AllOrSome::All => Ok(()),
AllOrSome::Some(ref allowed_headers) => { AllOrSome::Some(ref allowed_headers) => {
if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
{
if let Ok(headers) = hdr.to_str() { if let Ok(headers) = hdr.to_str() {
let mut hdrs = HashSet::new(); let mut hdrs = HashSet::new();
for hdr in headers.split(',') { for hdr in headers.split(',') {
match HeaderName::try_from(hdr.trim()) { match HeaderName::try_from(hdr.trim()) {
Ok(hdr) => hdrs.insert(hdr), Ok(hdr) => hdrs.insert(hdr),
Err(_) => return Err(CorsError::BadRequestHeaders) Err(_) => return Err(CorsError::BadRequestHeaders),
}; };
} }
if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) {
return Err(CorsError::HeadersNotAllowed) return Err(CorsError::HeadersNotAllowed);
} }
return Ok(()) return Ok(());
} }
Err(CorsError::BadRequestHeaders) Err(CorsError::BadRequestHeaders)
} else { } else {
@ -321,7 +345,6 @@ impl Cors {
} }
impl<S> Middleware<S> for Cors { impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
if self.inner.preflight && Method::OPTIONS == *req.method() { if self.inner.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?; self.validate_origin(req)?;
@ -330,9 +353,17 @@ impl<S> Middleware<S> for Cors {
// allowed headers // allowed headers
let headers = if let Some(headers) = self.inner.headers.as_ref() { let headers = if let Some(headers) = self.inner.headers.as_ref() {
Some(HeaderValue::try_from(&headers.iter().fold( Some(
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap()) HeaderValue::try_from(
} else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { &headers
.iter()
.fold(String::new(), |s, v| s + "," + v.as_str())
.as_str()[1..],
).unwrap(),
)
} else if let Some(hdr) = req.headers()
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
{
Some(hdr.clone()) Some(hdr.clone())
} else { } else {
None None
@ -342,31 +373,44 @@ impl<S> Middleware<S> for Cors {
HttpResponse::Ok() HttpResponse::Ok()
.if_some(self.inner.max_age.as_ref(), |max_age, resp| { .if_some(self.inner.max_age.as_ref(), |max_age, resp| {
let _ = resp.header( let _ = resp.header(
header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) header::ACCESS_CONTROL_MAX_AGE,
format!("{}", max_age).as_str(),
);
})
.if_some(headers, |headers, resp| { .if_some(headers, |headers, resp| {
let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); }) let _ =
resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
})
.if_true(self.inner.origins.is_all(), |resp| { .if_true(self.inner.origins.is_all(), |resp| {
if self.inner.send_wildcard { if self.inner.send_wildcard {
resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
} else { } else {
let origin = req.headers().get(header::ORIGIN).unwrap(); let origin = req.headers().get(header::ORIGIN).unwrap();
resp.header( resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
} }
}) })
.if_true(self.inner.origins.is_some(), |resp| { .if_true(self.inner.origins.is_some(), |resp| {
resp.header( resp.header(
header::ACCESS_CONTROL_ALLOW_ORIGIN, header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone()); self.inner.origins_str.as_ref().unwrap().clone(),
);
}) })
.if_true(self.inner.supports_credentials, |resp| { .if_true(self.inner.supports_credentials, |resp| {
resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}) })
.header( .header(
header::ACCESS_CONTROL_ALLOW_METHODS, header::ACCESS_CONTROL_ALLOW_METHODS,
&self.inner.methods.iter().fold( &self.inner
String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]) .methods
.finish())) .iter()
.fold(String::new(), |s, v| s + "," + v.as_str())
.as_str()[1..],
)
.finish(),
))
} else { } else {
self.validate_origin(req)?; self.validate_origin(req)?;
@ -374,32 +418,40 @@ impl<S> Middleware<S> for Cors {
} }
} }
fn response(&self, req: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> { fn response(
&self, req: &mut HttpRequest<S>, mut resp: HttpResponse
) -> Result<Response> {
match self.inner.origins { match self.inner.origins {
AllOrSome::All => { AllOrSome::All => {
if self.inner.send_wildcard { if self.inner.send_wildcard {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("*"),
);
} else if let Some(origin) = req.headers().get(header::ORIGIN) { } else if let Some(origin) = req.headers().get(header::ORIGIN) {
resp.headers_mut().insert( resp.headers_mut()
header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
} }
} }
AllOrSome::Some(_) => { AllOrSome::Some(_) => {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN, header::ACCESS_CONTROL_ALLOW_ORIGIN,
self.inner.origins_str.as_ref().unwrap().clone()); self.inner.origins_str.as_ref().unwrap().clone(),
);
} }
} }
if let Some(ref expose) = self.inner.expose_hdrs { if let Some(ref expose) = self.inner.expose_hdrs {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS, header::ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::try_from(expose.as_str()).unwrap()); HeaderValue::try_from(expose.as_str()).unwrap(),
);
} }
if self.inner.supports_credentials { if self.inner.supports_credentials {
resp.headers_mut().insert( resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
} }
if self.inner.vary_header { if self.inner.vary_header {
let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) {
@ -416,13 +468,15 @@ impl<S> Middleware<S> for Cors {
} }
} }
/// Structure that follows the builder pattern for building `Cors` middleware structs. /// Structure that follows the builder pattern for building `Cors` middleware
/// structs.
/// ///
/// To construct a cors: /// To construct a cors:
/// ///
/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. /// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
/// 2. Use any of the builder methods to set fields in the backend. /// 2. Use any of the builder methods to set fields in the backend.
/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend. /// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
/// constructed backend.
/// ///
/// # Example /// # Example
/// ///
@ -442,7 +496,7 @@ impl<S> Middleware<S> for Cors {
/// .finish(); /// .finish();
/// # } /// # }
/// ``` /// ```
pub struct CorsBuilder<S=()> { pub struct CorsBuilder<S = ()> {
cors: Option<Inner>, cors: Option<Inner>,
methods: bool, methods: bool,
error: Option<http::Error>, error: Option<http::Error>,
@ -451,26 +505,26 @@ pub struct CorsBuilder<S=()> {
app: Option<App<S>>, app: Option<App<S>>,
} }
fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<http::Error>) fn cors<'a>(
-> Option<&'a mut Inner> parts: &'a mut Option<Inner>, err: &Option<http::Error>
{ ) -> Option<&'a mut Inner> {
if err.is_some() { if err.is_some() {
return None return None;
} }
parts.as_mut() parts.as_mut()
} }
impl<S: 'static> CorsBuilder<S> { impl<S: 'static> CorsBuilder<S> {
/// Add an origin that are allowed to make requests. /// Add an origin that are allowed to make requests.
/// Will be verified against the `Origin` request header. /// Will be verified against the `Origin` request header.
/// ///
/// When `All` is set, and `send_wildcard` is set, "*" will be sent in /// When `All` is set, and `send_wildcard` is set, "*" will be sent in
/// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request /// the `Access-Control-Allow-Origin` response header. Otherwise, the
/// header will be echoed back in the `Access-Control-Allow-Origin` response header. /// client's `Origin` request header will be echoed back in the
/// `Access-Control-Allow-Origin` response header.
/// ///
/// When `Some` is set, the client's `Origin` request header will be checked in a /// When `Some` is set, the client's `Origin` request header will be
/// case-sensitive manner. /// checked in a case-sensitive manner.
/// ///
/// This is the `list of origins` in the /// This is the `list of origins` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
@ -497,15 +551,17 @@ impl<S: 'static> CorsBuilder<S> {
self self
} }
/// Set a list of methods which the allowed origins are allowed to access for /// Set a list of methods which the allowed origins are allowed to access
/// requests. /// for requests.
/// ///
/// This is the `list of methods` in the /// This is the `list of methods` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
pub fn allowed_methods<U, M>(&mut self, methods: U) -> &mut CorsBuilder<S> pub fn allowed_methods<U, M>(&mut self, methods: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=M>, Method: HttpTryFrom<M> where
U: IntoIterator<Item = M>,
Method: HttpTryFrom<M>,
{ {
self.methods = true; self.methods = true;
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
@ -513,20 +569,21 @@ impl<S: 'static> CorsBuilder<S> {
match Method::try_from(m) { match Method::try_from(m) {
Ok(method) => { Ok(method) => {
cors.methods.insert(method); cors.methods.insert(method);
}, }
Err(e) => { Err(e) => {
self.error = Some(e.into()); self.error = Some(e.into());
break break;
} }
} }
}; }
} }
self self
} }
/// Set an allowed header /// Set an allowed header
pub fn allowed_header<H>(&mut self, header: H) -> &mut CorsBuilder<S> pub fn allowed_header<H>(&mut self, header: H) -> &mut CorsBuilder<S>
where HeaderName: HttpTryFrom<H> where
HeaderName: HttpTryFrom<H>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match HeaderName::try_from(header) { match HeaderName::try_from(header) {
@ -547,15 +604,18 @@ impl<S: 'static> CorsBuilder<S> {
/// Set a list of header field names which can be used when /// Set a list of header field names which can be used when
/// this resource is accessed by allowed origins. /// this resource is accessed by allowed origins.
/// ///
/// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers` /// If `All` is set, whatever is requested by the client in
/// will be echoed back in the `Access-Control-Allow-Headers` header. /// `Access-Control-Request-Headers` will be echoed back in the
/// `Access-Control-Allow-Headers` header.
/// ///
/// This is the `list of headers` in the /// This is the `list of headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// Defaults to `All`. /// Defaults to `All`.
pub fn allowed_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S> pub fn allowed_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H> where
U: IntoIterator<Item = H>,
HeaderName: HttpTryFrom<H>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for h in headers { for h in headers {
@ -570,32 +630,35 @@ impl<S: 'static> CorsBuilder<S> {
} }
Err(e) => { Err(e) => {
self.error = Some(e.into()); self.error = Some(e.into());
break break;
} }
} }
}; }
} }
self self
} }
/// Set a list of headers which are safe to expose to the API of a CORS API specification. /// Set a list of headers which are safe to expose to the API of a CORS API
/// This corresponds to the `Access-Control-Expose-Headers` response header. /// specification. This corresponds to the
/// `Access-Control-Expose-Headers` response header.
/// ///
/// This is the `list of exposed headers` in the /// This is the `list of exposed headers` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// This defaults to an empty set. /// This defaults to an empty set.
pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S> pub fn expose_headers<U, H>(&mut self, headers: U) -> &mut CorsBuilder<S>
where U: IntoIterator<Item=H>, HeaderName: HttpTryFrom<H> where
U: IntoIterator<Item = H>,
HeaderName: HttpTryFrom<H>,
{ {
for h in headers { for h in headers {
match HeaderName::try_from(h) { match HeaderName::try_from(h) {
Ok(method) => { Ok(method) => {
self.expose_hdrs.insert(method); self.expose_hdrs.insert(method);
}, }
Err(e) => { Err(e) => {
self.error = Some(e.into()); self.error = Some(e.into());
break break;
} }
} }
} }
@ -615,16 +678,17 @@ impl<S: 'static> CorsBuilder<S> {
/// Set a wildcard origins /// Set a wildcard origins
/// ///
/// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard /// If send wildcard is set and the `allowed_origins` parameter is `All`, a
/// `Access-Control-Allow-Origin` response header is sent, rather than the requests /// wildcard `Access-Control-Allow-Origin` response header is sent,
/// `Origin` header. /// rather than the requests `Origin` header.
/// ///
/// This is the `supports credentials flag` in the /// This is the `supports credentials flag` in the
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and /// This **CANNOT** be used in conjunction with `allowed_origins` set to
/// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result /// `All` and `allow_credentials` set to `true`. Depending on the mode
/// in an `Error::CredentialsWithWildcardOrigin` error during actix launch or runtime. /// of usage, this will either result in an `Error::
/// CredentialsWithWildcardOrigin` error during actix launch or runtime.
/// ///
/// Defaults to `false`. /// Defaults to `false`.
pub fn send_wildcard(&mut self) -> &mut CorsBuilder<S> { pub fn send_wildcard(&mut self) -> &mut CorsBuilder<S> {
@ -636,11 +700,12 @@ impl<S: 'static> CorsBuilder<S> {
/// Allows users to make authenticated requests /// Allows users to make authenticated requests
/// ///
/// If true, injects the `Access-Control-Allow-Credentials` header in responses. /// If true, injects the `Access-Control-Allow-Credentials` header in
/// This allows cookies and credentials to be submitted across domains. /// responses. This allows cookies and credentials to be submitted
/// across domains.
/// ///
/// This option cannot be used in conjunction with an `allowed_origin` set to `All` /// This option cannot be used in conjunction with an `allowed_origin` set
/// and `send_wildcards` set to `true`. /// to `All` and `send_wildcards` set to `true`.
/// ///
/// Defaults to `false`. /// Defaults to `false`.
/// ///
@ -713,7 +778,8 @@ impl<S: 'static> CorsBuilder<S> {
/// } /// }
/// ``` /// ```
pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut CorsBuilder<S> pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut CorsBuilder<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{ {
// add resource handler // add resource handler
let mut handler = ResourceHandler::default(); let mut handler = ResourceHandler::default();
@ -725,9 +791,15 @@ impl<S: 'static> CorsBuilder<S> {
fn construct(&mut self) -> Cors { fn construct(&mut self) -> Cors {
if !self.methods { if !self.methods {
self.allowed_methods(vec![Method::GET, Method::HEAD, self.allowed_methods(vec![
Method::POST, Method::OPTIONS, Method::PUT, Method::GET,
Method::PATCH, Method::DELETE]); Method::HEAD,
Method::POST,
Method::OPTIONS,
Method::PUT,
Method::PATCH,
Method::DELETE,
]);
} }
if let Some(e) = self.error.take() { if let Some(e) = self.error.take() {
@ -741,16 +813,23 @@ impl<S: 'static> CorsBuilder<S> {
} }
if let AllOrSome::Some(ref origins) = cors.origins { if let AllOrSome::Some(ref origins) = cors.origins {
let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v)); let s = origins
.iter()
.fold(String::new(), |s, v| s + &format!("{}", v));
cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap()); cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap());
} }
if !self.expose_hdrs.is_empty() { if !self.expose_hdrs.is_empty() {
cors.expose_hdrs = Some( cors.expose_hdrs = Some(
self.expose_hdrs.iter().fold( self.expose_hdrs
String::new(), |s, v| s + v.as_str())[1..].to_owned()); .iter()
.fold(String::new(), |s, v| s + v.as_str())[1..]
.to_owned(),
);
}
Cors {
inner: Rc::new(cors),
} }
Cors{inner: Rc::new(cors)}
} }
/// Finishes building and returns the built `Cors` instance. /// Finishes building and returns the built `Cors` instance.
@ -758,13 +837,16 @@ impl<S: 'static> CorsBuilder<S> {
/// This method panics in case of any configuration error. /// This method panics in case of any configuration error.
pub fn finish(&mut self) -> Cors { pub fn finish(&mut self) -> Cors {
if !self.resources.is_empty() { if !self.resources.is_empty() {
panic!("CorsBuilder::resource() was used, panic!(
to construct CORS `.register(app)` method should be used"); "CorsBuilder::resource() was used,
to construct CORS `.register(app)` method should be used"
);
} }
self.construct() self.construct()
} }
/// Finishes building Cors middleware and register middleware for application /// Finishes building Cors middleware and register middleware for
/// application
/// ///
/// This method panics in case of any configuration error or if non of /// This method panics in case of any configuration error or if non of
/// resources are registered. /// resources are registered.
@ -774,8 +856,9 @@ impl<S: 'static> CorsBuilder<S> {
} }
let cors = self.construct(); let cors = self.construct();
let mut app = self.app.take().expect( let mut app = self.app
"CorsBuilder has to be constructed with Cors::for_app(app)"); .take()
.expect("CorsBuilder has to be constructed with Cors::for_app(app)");
// register resources // register resources
for (path, mut resource) in self.resources.drain(..) { for (path, mut resource) in self.resources.drain(..) {
@ -787,7 +870,6 @@ impl<S: 'static> CorsBuilder<S> {
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -845,8 +927,8 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let cors = Cors::default(); let cors = Cors::default();
let mut req = TestRequest::with_header( let mut req =
"Origin", "https://www.example.com").finish(); TestRequest::with_header("Origin", "https://www.example.com").finish();
assert!(cors.start(&mut req).ok().unwrap().is_done()) assert!(cors.start(&mut req).ok().unwrap().is_done())
} }
@ -861,8 +943,7 @@ mod tests {
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish(); .finish();
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header("Origin", "https://www.example.com")
"Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.finish(); .finish();
@ -877,23 +958,35 @@ mod tests {
let mut req = TestRequest::with_header("Origin", "https://www.example.com") let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT") .header(
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS) .method(Method::OPTIONS)
.finish(); .finish();
let resp = cors.start(&mut req).unwrap().response(); let resp = cors.start(&mut req).unwrap().response();
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!( assert_eq!(
&b"3600"[..], &b"3600"[..],
resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()); resp.headers()
.get(header::ACCESS_CONTROL_MAX_AGE)
.unwrap()
.as_bytes()
);
//assert_eq!( //assert_eq!(
// &b"authorization,accept,content-type"[..], // &b"authorization,accept,content-type"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes()); // resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().
//assert_eq!( // as_bytes()); assert_eq!(
// &b"POST,GET,OPTIONS"[..], // &b"POST,GET,OPTIONS"[..],
// resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes()); // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().
// as_bytes());
Rc::get_mut(&mut cors.inner).unwrap().preflight = false; Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
assert!(cors.start(&mut req).unwrap().is_done()); assert!(cors.start(&mut req).unwrap().is_done());
@ -903,7 +996,8 @@ mod tests {
#[should_panic(expected = "MissingOrigin")] #[should_panic(expected = "MissingOrigin")]
fn test_validate_missing_origin() { fn test_validate_missing_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish(); .allowed_origin("https://www.example.com")
.finish();
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
cors.start(&mut req).unwrap(); cors.start(&mut req).unwrap();
@ -913,7 +1007,8 @@ mod tests {
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish(); .allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.unknown.com") let mut req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
@ -924,7 +1019,8 @@ mod tests {
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
let cors = Cors::build() let cors = Cors::build()
.allowed_origin("https://www.example.com").finish(); .allowed_origin("https://www.example.com")
.finish();
let mut req = TestRequest::with_header("Origin", "https://www.example.com") let mut req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
@ -940,16 +1036,23 @@ mod tests {
let mut req = TestRequest::default().method(Method::GET).finish(); let mut req = TestRequest::default().method(Method::GET).finish();
let resp: HttpResponse = HttpResponse::Ok().into(); let resp: HttpResponse = HttpResponse::Ok().into();
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none()); assert!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()
);
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header("Origin", "https://www.example.com")
"Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.finish(); .finish();
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!( assert_eq!(
&b"https://www.example.com"[..], &b"https://www.example.com"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
} }
#[test] #[test]
@ -963,8 +1066,7 @@ mod tests {
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish(); .finish();
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header("Origin", "https://www.example.com")
"Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.finish(); .finish();
@ -972,10 +1074,15 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!( assert_eq!(
&b"Origin"[..], &b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()); resp.headers().get(header::VARY).unwrap().as_bytes()
);
let resp: HttpResponse = HttpResponse::Ok() let resp: HttpResponse = HttpResponse::Ok()
.header(header::VARY, "Accept") .header(header::VARY, "Accept")
@ -983,7 +1090,8 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!( assert_eq!(
&b"Accept, Origin"[..], &b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()); resp.headers().get(header::VARY).unwrap().as_bytes()
);
let cors = Cors::build() let cors = Cors::build()
.disable_vary_header() .disable_vary_header()
@ -993,25 +1101,33 @@ mod tests {
let resp = cors.response(&mut req, resp).unwrap().response(); let resp = cors.response(&mut req, resp).unwrap().response();
assert_eq!( assert_eq!(
&b"https://www.example.com"[..], &b"https://www.example.com"[..],
resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
} }
#[test] #[test]
fn cors_resource() { fn cors_resource() {
let mut srv = test::TestServer::with_factory( let mut srv = test::TestServer::with_factory(|| {
|| App::new() App::new().configure(|app| {
.configure( Cors::for_app(app)
|app| Cors::for_app(app) .allowed_origin("https://www.example.com")
.allowed_origin("https://www.example.com") .resource("/test", |r| r.f(|_| HttpResponse::Ok()))
.resource("/test", |r| r.f(|_| HttpResponse::Ok())) .register()
.register())); })
});
let request = srv.get().uri(srv.url("/test")).finish().unwrap(); let request = srv.get().uri(srv.url("/test")).finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST); assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let request = srv.get().uri(srv.url("/test")) let request = srv.get()
.header("ORIGIN", "https://www.example.com").finish().unwrap(); .uri(srv.url("/test"))
.header("ORIGIN", "https://www.example.com")
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.status(), StatusCode::OK);
} }

View File

@ -48,8 +48,8 @@ use std::borrow::Cow;
use std::collections::HashSet; use std::collections::HashSet;
use bytes::Bytes; use bytes::Bytes;
use error::{Result, ResponseError}; use error::{ResponseError, Result};
use http::{HeaderMap, HttpTryFrom, Uri, header}; use http::{header, HeaderMap, HttpTryFrom, Uri};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
@ -59,13 +59,13 @@ use middleware::{Middleware, Started};
#[derive(Debug, Fail)] #[derive(Debug, Fail)]
pub enum CsrfError { pub enum CsrfError {
/// The HTTP request header `Origin` was required but not provided. /// The HTTP request header `Origin` was required but not provided.
#[fail(display="Origin header required")] #[fail(display = "Origin header required")]
MissingOrigin, MissingOrigin,
/// The HTTP request header `Origin` could not be parsed correctly. /// The HTTP request header `Origin` could not be parsed correctly.
#[fail(display="Could not parse Origin header")] #[fail(display = "Could not parse Origin header")]
BadOrigin, BadOrigin,
/// The cross-site request was denied. /// The cross-site request was denied.
#[fail(display="Cross-site request denied")] #[fail(display = "Cross-site request denied")]
CsrDenied, CsrDenied,
} }
@ -80,15 +80,14 @@ fn uri_origin(uri: &Uri) -> Option<String> {
(Some(scheme), Some(host), Some(port)) => { (Some(scheme), Some(host), Some(port)) => {
Some(format!("{}://{}:{}", scheme, host, port)) Some(format!("{}://{}:{}", scheme, host, port))
} }
(Some(scheme), Some(host), None) => { (Some(scheme), Some(host), None) => Some(format!("{}://{}", scheme, host)),
Some(format!("{}://{}", scheme, host)) _ => None,
}
_ => None
} }
} }
fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> { fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
headers.get(header::ORIGIN) headers
.get(header::ORIGIN)
.map(|origin| { .map(|origin| {
origin origin
.to_str() .to_str()
@ -96,15 +95,14 @@ fn origin(headers: &HeaderMap) -> Option<Result<Cow<str>, CsrfError>> {
.map(|o| o.into()) .map(|o| o.into())
}) })
.or_else(|| { .or_else(|| {
headers.get(header::REFERER) headers.get(header::REFERER).map(|referer| {
.map(|referer| { Uri::try_from(Bytes::from(referer.as_bytes()))
Uri::try_from(Bytes::from(referer.as_bytes())) .ok()
.ok() .as_ref()
.as_ref() .and_then(uri_origin)
.and_then(uri_origin) .ok_or(CsrfError::BadOrigin)
.ok_or(CsrfError::BadOrigin) .map(|o| o.into())
.map(|o| o.into()) })
})
}) })
} }
@ -194,7 +192,8 @@ impl CsrfFilter {
let is_upgrade = req.headers().contains_key(header::UPGRADE); let is_upgrade = req.headers().contains_key(header::UPGRADE);
let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade);
if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with"))
{
Ok(()) Ok(())
} else if let Some(header) = origin(req.headers()) { } else if let Some(header) = origin(req.headers()) {
match header { match header {
@ -225,8 +224,7 @@ mod tests {
#[test] #[test]
fn test_safe() { fn test_safe() {
let csrf = CsrfFilter::new() let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
.allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::HEAD) .method(Method::HEAD)
@ -237,8 +235,7 @@ mod tests {
#[test] #[test]
fn test_csrf() { fn test_csrf() {
let csrf = CsrfFilter::new() let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
.allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::POST) .method(Method::POST)
@ -249,11 +246,12 @@ mod tests {
#[test] #[test]
fn test_referer() { fn test_referer() {
let csrf = CsrfFilter::new() let csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
.allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param") let mut req = TestRequest::with_header(
.method(Method::POST) "Referer",
"https://www.example.com/some/path?query=param",
).method(Method::POST)
.finish(); .finish();
assert!(csrf.start(&mut req).is_ok()); assert!(csrf.start(&mut req).is_ok());
@ -261,8 +259,7 @@ mod tests {
#[test] #[test]
fn test_upgrade() { fn test_upgrade() {
let strict_csrf = CsrfFilter::new() let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
.allowed_origin("https://www.example.com");
let lax_csrf = CsrfFilter::new() let lax_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")

View File

@ -1,11 +1,11 @@
//! Default response headers //! Default response headers
use http::{HeaderMap, HttpTryFrom};
use http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use http::{HeaderMap, HttpTryFrom};
use error::Result; use error::Result;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Response, Middleware}; use middleware::{Middleware, Response};
/// `Middleware` for setting default response headers. /// `Middleware` for setting default response headers.
/// ///
@ -27,14 +27,17 @@ use middleware::{Response, Middleware};
/// .finish(); /// .finish();
/// } /// }
/// ``` /// ```
pub struct DefaultHeaders{ pub struct DefaultHeaders {
ct: bool, ct: bool,
headers: HeaderMap, headers: HeaderMap,
} }
impl Default for DefaultHeaders { impl Default for DefaultHeaders {
fn default() -> Self { fn default() -> Self {
DefaultHeaders{ct: false, headers: HeaderMap::new()} DefaultHeaders {
ct: false,
headers: HeaderMap::new(),
}
} }
} }
@ -48,15 +51,16 @@ impl DefaultHeaders {
#[inline] #[inline]
#[cfg_attr(feature = "cargo-clippy", allow(match_wild_err_arm))] #[cfg_attr(feature = "cargo-clippy", allow(match_wild_err_arm))]
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, where
HeaderValue: HttpTryFrom<V> HeaderName: HttpTryFrom<K>,
HeaderValue: HttpTryFrom<V>,
{ {
match HeaderName::try_from(key) { match HeaderName::try_from(key) {
Ok(key) => { Ok(key) => match HeaderValue::try_from(value) {
match HeaderValue::try_from(value) { Ok(value) => {
Ok(value) => { self.headers.append(key, value); } self.headers.append(key, value);
Err(_) => panic!("Can not create header value"),
} }
Err(_) => panic!("Can not create header value"),
}, },
Err(_) => panic!("Can not create header name"), Err(_) => panic!("Can not create header name"),
} }
@ -71,8 +75,9 @@ impl DefaultHeaders {
} }
impl<S> Middleware<S> for DefaultHeaders { impl<S> Middleware<S> for DefaultHeaders {
fn response(
fn response(&self, _: &mut HttpRequest<S>, mut resp: HttpResponse) -> Result<Response> { &self, _: &mut HttpRequest<S>, mut resp: HttpResponse
) -> Result<Response> {
for (key, value) in self.headers.iter() { for (key, value) in self.headers.iter() {
if !resp.headers().contains_key(key) { if !resp.headers().contains_key(key) {
resp.headers_mut().insert(key, value.clone()); resp.headers_mut().insert(key, value.clone());
@ -81,7 +86,9 @@ impl<S> Middleware<S> for DefaultHeaders {
// default content-type // default content-type
if self.ct && !resp.headers().contains_key(CONTENT_TYPE) { if self.ct && !resp.headers().contains_key(CONTENT_TYPE) {
resp.headers_mut().insert( resp.headers_mut().insert(
CONTENT_TYPE, HeaderValue::from_static("application/octet-stream")); CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
} }
Ok(Response::Done(resp)) Ok(Response::Done(resp))
} }
@ -94,8 +101,7 @@ mod tests {
#[test] #[test]
fn test_default_headers() { fn test_default_headers() {
let mw = DefaultHeaders::new() let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001");
.header(CONTENT_TYPE, "0001");
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
@ -106,7 +112,9 @@ mod tests {
}; };
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let resp = HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish(); let resp = HttpResponse::Ok()
.header(CONTENT_TYPE, "0002")
.finish();
let resp = match mw.response(&mut req, resp) { let resp = match mw.response(&mut req, resp) {
Ok(Response::Done(resp)) => resp, Ok(Response::Done(resp)) => resp,
_ => panic!(), _ => panic!(),

View File

@ -6,14 +6,13 @@ use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Response}; use middleware::{Middleware, Response};
type ErrorHandler<S> = Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response>; type ErrorHandler<S> = Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response>;
/// `Middleware` for allowing custom handlers for responses. /// `Middleware` for allowing custom handlers for responses.
/// ///
/// You can use `ErrorHandlers::handler()` method to register a custom error handler /// You can use `ErrorHandlers::handler()` method to register a custom error
/// for specific status code. You can modify existing response or create completly new /// handler for specific status code. You can modify existing response or
/// one. /// create completly new one.
/// ///
/// ## Example /// ## Example
/// ///
@ -53,7 +52,6 @@ impl<S> Default for ErrorHandlers<S> {
} }
impl<S> ErrorHandlers<S> { impl<S> ErrorHandlers<S> {
/// Construct new `ErrorHandlers` instance /// Construct new `ErrorHandlers` instance
pub fn new() -> Self { pub fn new() -> Self {
ErrorHandlers::default() ErrorHandlers::default()
@ -61,7 +59,8 @@ impl<S> ErrorHandlers<S> {
/// Register error handler for specified status code /// Register error handler for specified status code
pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
where F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static where
F: Fn(&mut HttpRequest<S>, HttpResponse) -> Result<Response> + 'static,
{ {
self.handlers.insert(status, Box::new(handler)); self.handlers.insert(status, Box::new(handler));
self self
@ -69,8 +68,9 @@ impl<S> ErrorHandlers<S> {
} }
impl<S: 'static> Middleware<S> for ErrorHandlers<S> { impl<S: 'static> Middleware<S> for ErrorHandlers<S> {
fn response(
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> { &self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
if let Some(handler) = self.handlers.get(&resp.status()) { if let Some(handler) = self.handlers.get(&resp.status()) {
handler(req, resp) handler(req, resp)
} else { } else {
@ -90,11 +90,11 @@ mod tests {
builder.header(CONTENT_TYPE, "0001"); builder.header(CONTENT_TYPE, "0001");
Ok(Response::Done(builder.into())) Ok(Response::Done(builder.into()))
} }
#[test] #[test]
fn test_handler() { fn test_handler() {
let mw = ErrorHandlers::new() let mw =
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();
let resp = HttpResponse::InternalServerError().finish(); let resp = HttpResponse::InternalServerError().finish();

View File

@ -4,14 +4,14 @@ use std::fmt;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use libc; use libc;
use time;
use regex::Regex; use regex::Regex;
use time;
use error::Result; use error::Result;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Started, Finished}; use middleware::{Finished, Middleware, Started};
/// `Middleware` for logging request and response info to the terminal. /// `Middleware` for logging request and response info to the terminal.
/// `Logger` middleware uses standard log crate to log information. You should /// `Logger` middleware uses standard log crate to log information. You should
@ -21,7 +21,8 @@ use middleware::{Middleware, Started, Finished};
/// ## Usage /// ## Usage
/// ///
/// Create `Logger` middleware with the specified `format`. /// Create `Logger` middleware with the specified `format`.
/// Default `Logger` could be created with `default` method, it uses the default format: /// Default `Logger` could be created with `default` method, it uses the
/// default format:
/// ///
/// ```ignore /// ```ignore
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
@ -59,7 +60,8 @@ use middleware::{Middleware, Started, Finished};
/// ///
/// `%b` Size of response in bytes, including HTTP headers /// `%b` Size of response in bytes, including HTTP headers
/// ///
/// `%T` Time taken to serve the request, in seconds with floating fraction in .06f format /// `%T` Time taken to serve the request, in seconds with floating fraction in
/// .06f format
/// ///
/// `%D` Time taken to serve the request, in milliseconds /// `%D` Time taken to serve the request, in milliseconds
/// ///
@ -76,7 +78,9 @@ pub struct Logger {
impl Logger { impl Logger {
/// Create `Logger` middleware with the specified `format`. /// Create `Logger` middleware with the specified `format`.
pub fn new(format: &str) -> Logger { pub fn new(format: &str) -> Logger {
Logger { format: Format::new(format) } Logger {
format: Format::new(format),
}
} }
} }
@ -87,14 +91,15 @@ impl Default for Logger {
/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
/// ``` /// ```
fn default() -> Logger { fn default() -> Logger {
Logger { format: Format::default() } Logger {
format: Format::default(),
}
} }
} }
struct StartTime(time::Tm); struct StartTime(time::Tm);
impl Logger { impl Logger {
fn log<S>(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) { fn log<S>(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) {
let entry_time = req.extensions().get::<StartTime>().unwrap().0; let entry_time = req.extensions().get::<StartTime>().unwrap().0;
@ -109,7 +114,6 @@ impl Logger {
} }
impl<S> Middleware<S> for Logger { impl<S> Middleware<S> for Logger {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
req.extensions().insert(StartTime(time::now())); req.extensions().insert(StartTime(time::now()));
Ok(Started::Done) Ok(Started::Done)
@ -153,29 +157,26 @@ impl Format {
idx = m.end(); idx = m.end();
if let Some(key) = cap.get(2) { if let Some(key) = cap.get(2) {
results.push( results.push(match cap.get(3).unwrap().as_str() {
match cap.get(3).unwrap().as_str() { "i" => FormatText::RequestHeader(key.as_str().to_owned()),
"i" => FormatText::RequestHeader(key.as_str().to_owned()), "o" => FormatText::ResponseHeader(key.as_str().to_owned()),
"o" => FormatText::ResponseHeader(key.as_str().to_owned()), "e" => FormatText::EnvironHeader(key.as_str().to_owned()),
"e" => FormatText::EnvironHeader(key.as_str().to_owned()), _ => unreachable!(),
_ => unreachable!(), })
})
} else { } else {
let m = cap.get(1).unwrap(); let m = cap.get(1).unwrap();
results.push( results.push(match m.as_str() {
match m.as_str() { "%" => FormatText::Percent,
"%" => FormatText::Percent, "a" => FormatText::RemoteAddr,
"a" => FormatText::RemoteAddr, "t" => FormatText::RequestTime,
"t" => FormatText::RequestTime, "P" => FormatText::Pid,
"P" => FormatText::Pid, "r" => FormatText::RequestLine,
"r" => FormatText::RequestLine, "s" => FormatText::ResponseStatus,
"s" => FormatText::ResponseStatus, "b" => FormatText::ResponseSize,
"b" => FormatText::ResponseSize, "T" => FormatText::Time,
"T" => FormatText::Time, "D" => FormatText::TimeMillis,
"D" => FormatText::TimeMillis, _ => FormatText::Str(m.as_str().to_owned()),
_ => FormatText::Str(m.as_str().to_owned()), });
}
);
} }
} }
if idx != s.len() { if idx != s.len() {
@ -207,12 +208,10 @@ pub enum FormatText {
} }
impl FormatText { impl FormatText {
fn render<S>(
fn render<S>(&self, fmt: &mut Formatter, &self, fmt: &mut Formatter, req: &HttpRequest<S>, resp: &HttpResponse,
req: &HttpRequest<S>, entry_time: time::Tm,
resp: &HttpResponse, ) -> Result<(), fmt::Error> {
entry_time: time::Tm) -> Result<(), fmt::Error>
{
match *self { match *self {
FormatText::Str(ref string) => fmt.write_str(string), FormatText::Str(ref string) => fmt.write_str(string),
FormatText::Percent => "%".fmt(fmt), FormatText::Percent => "%".fmt(fmt),
@ -220,26 +219,33 @@ impl FormatText {
if req.query_string().is_empty() { if req.query_string().is_empty() {
fmt.write_fmt(format_args!( fmt.write_fmt(format_args!(
"{} {} {:?}", "{} {} {:?}",
req.method(), req.path(), req.version())) req.method(),
req.path(),
req.version()
))
} else { } else {
fmt.write_fmt(format_args!( fmt.write_fmt(format_args!(
"{} {}?{} {:?}", "{} {}?{} {:?}",
req.method(), req.path(), req.query_string(), req.version())) req.method(),
req.path(),
req.query_string(),
req.version()
))
} }
}, }
FormatText::ResponseStatus => resp.status().as_u16().fmt(fmt), FormatText::ResponseStatus => resp.status().as_u16().fmt(fmt),
FormatText::ResponseSize => resp.response_size().fmt(fmt), FormatText::ResponseSize => resp.response_size().fmt(fmt),
FormatText::Pid => unsafe{libc::getpid().fmt(fmt)}, FormatText::Pid => unsafe { libc::getpid().fmt(fmt) },
FormatText::Time => { FormatText::Time => {
let rt = time::now() - entry_time; let rt = time::now() - entry_time;
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000_000.0; let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000_000.0;
fmt.write_fmt(format_args!("{:.6}", rt)) fmt.write_fmt(format_args!("{:.6}", rt))
}, }
FormatText::TimeMillis => { FormatText::TimeMillis => {
let rt = time::now() - entry_time; let rt = time::now() - entry_time;
let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000.0; let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000.0;
fmt.write_fmt(format_args!("{:.6}", rt)) fmt.write_fmt(format_args!("{:.6}", rt))
}, }
FormatText::RemoteAddr => { FormatText::RemoteAddr => {
if let Some(remote) = req.connection_info().remote() { if let Some(remote) = req.connection_info().remote() {
return remote.fmt(fmt); return remote.fmt(fmt);
@ -247,14 +253,17 @@ impl FormatText {
"-".fmt(fmt) "-".fmt(fmt)
} }
} }
FormatText::RequestTime => { FormatText::RequestTime => entry_time
entry_time.strftime("[%d/%b/%Y:%H:%M:%S %z]") .strftime("[%d/%b/%Y:%H:%M:%S %z]")
.unwrap() .unwrap()
.fmt(fmt) .fmt(fmt),
}
FormatText::RequestHeader(ref name) => { FormatText::RequestHeader(ref name) => {
let s = if let Some(val) = req.headers().get(name) { let s = if let Some(val) = req.headers().get(name) {
if let Ok(s) = val.to_str() { s } else { "-" } if let Ok(s) = val.to_str() {
s
} else {
"-"
}
} else { } else {
"-" "-"
}; };
@ -262,7 +271,11 @@ impl FormatText {
} }
FormatText::ResponseHeader(ref name) => { FormatText::ResponseHeader(ref name) => {
let s = if let Some(val) = resp.headers().get(name) { let s = if let Some(val) = resp.headers().get(name) {
if let Ok(s) = val.to_str() { s } else { "-" } if let Ok(s) = val.to_str() {
s
} else {
"-"
}
} else { } else {
"-" "-"
}; };
@ -279,8 +292,7 @@ impl FormatText {
} }
} }
pub(crate) struct FormatDisplay<'a>( pub(crate) struct FormatDisplay<'a>(&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
&'a Fn(&mut Formatter) -> Result<(), fmt::Error>);
impl<'a> fmt::Display for FormatDisplay<'a> { impl<'a> fmt::Display for FormatDisplay<'a> {
fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
@ -291,19 +303,27 @@ impl<'a> fmt::Display for FormatDisplay<'a> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use http::header::{self, HeaderMap};
use http::{Method, StatusCode, Uri, Version};
use std::str::FromStr; use std::str::FromStr;
use time; use time;
use http::{Method, Version, StatusCode, Uri};
use http::header::{self, HeaderMap};
#[test] #[test]
fn test_logger() { fn test_logger() {
let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test");
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB")); headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static("ACTIX-WEB"),
);
let mut req = HttpRequest::new( let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let resp = HttpResponse::build(StatusCode::OK) let resp = HttpResponse::build(StatusCode::OK)
.header("X-Test", "ttt") .header("X-Test", "ttt")
.force_close() .force_close()
@ -333,10 +353,20 @@ mod tests {
let format = Format::default(); let format = Format::default();
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB")); headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static("ACTIX-WEB"),
);
let req = HttpRequest::new( let req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); Method::GET,
let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
let entry_time = time::now(); let entry_time = time::now();
let render = |fmt: &mut Formatter| { let render = |fmt: &mut Formatter| {
@ -351,9 +381,15 @@ mod tests {
assert!(s.contains("ACTIX-WEB")); assert!(s.contains("ACTIX-WEB"));
let req = HttpRequest::new( let req = HttpRequest::new(
Method::GET, Uri::from_str("/?test").unwrap(), Method::GET,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/?test").unwrap(),
let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); Version::HTTP_11,
HeaderMap::new(),
None,
);
let resp = HttpResponse::build(StatusCode::OK)
.force_close()
.finish();
let entry_time = time::now(); let entry_time = time::now();
let render = |fmt: &mut Formatter| { let render = |fmt: &mut Formatter| {

View File

@ -7,19 +7,19 @@ use httpresponse::HttpResponse;
mod logger; mod logger;
#[cfg(feature = "session")]
mod session;
mod defaultheaders;
mod errhandlers;
pub mod cors; pub mod cors;
pub mod csrf; pub mod csrf;
pub use self::logger::Logger; mod defaultheaders;
pub use self::errhandlers::ErrorHandlers; mod errhandlers;
#[cfg(feature = "session")]
mod session;
pub use self::defaultheaders::DefaultHeaders; pub use self::defaultheaders::DefaultHeaders;
pub use self::errhandlers::ErrorHandlers;
pub use self::logger::Logger;
#[cfg(feature = "session")] #[cfg(feature = "session")]
pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage, pub use self::session::{CookieSessionBackend, CookieSessionError, RequestSession,
CookieSessionError, CookieSessionBackend}; Session, SessionBackend, SessionImpl, SessionStorage};
/// Middleware start result /// Middleware start result
pub enum Started { pub enum Started {
@ -29,7 +29,7 @@ pub enum Started {
/// handler execution halts. /// handler execution halts.
Response(HttpResponse), Response(HttpResponse),
/// Execution completed, runs future to completion. /// Execution completed, runs future to completion.
Future(Box<Future<Item=Option<HttpResponse>, Error=Error>>), Future(Box<Future<Item = Option<HttpResponse>, Error = Error>>),
} }
/// Middleware execution result /// Middleware execution result
@ -37,7 +37,7 @@ pub enum Response {
/// New http response got generated /// New http response got generated
Done(HttpResponse), Done(HttpResponse),
/// Result is a future that resolves to a new http response /// Result is a future that resolves to a new http response
Future(Box<Future<Item=HttpResponse, Error=Error>>), Future(Box<Future<Item = HttpResponse, Error = Error>>),
} }
/// Middleware finish result /// Middleware finish result
@ -45,13 +45,12 @@ pub enum Finished {
/// Execution completed /// Execution completed
Done, Done,
/// Execution completed, but run future to completion /// Execution completed, but run future to completion
Future(Box<Future<Item=(), Error=Error>>), Future(Box<Future<Item = (), Error = Error>>),
} }
/// Middleware definition /// Middleware definition
#[allow(unused_variables)] #[allow(unused_variables)]
pub trait Middleware<S>: 'static { pub trait Middleware<S>: 'static {
/// Method is called when request is ready. It may return /// Method is called when request is ready. It may return
/// future, which should resolve before next middleware get called. /// future, which should resolve before next middleware get called.
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
@ -60,7 +59,9 @@ pub trait Middleware<S>: 'static {
/// Method is called when handler returns response, /// Method is called when handler returns response,
/// but before sending http message to peer. /// but before sending http message to peer.
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> { fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
Ok(Response::Done(resp)) Ok(Response::Done(resp))
} }

View File

@ -1,21 +1,21 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::marker::PhantomData;
use std::collections::HashMap;
use cookie::{Cookie, CookieJar, Key};
use futures::Future;
use futures::future::{FutureResult, err as FutErr, ok as FutOk};
use http::header::{self, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json; use serde_json;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use serde::{Serialize, Deserialize};
use http::header::{self, HeaderValue};
use time::Duration; use time::Duration;
use cookie::{CookieJar, Cookie, Key};
use futures::Future;
use futures::future::{FutureResult, ok as FutOk, err as FutErr};
use error::{Result, Error, ResponseError}; use error::{Error, ResponseError, Result};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Started, Response}; use middleware::{Middleware, Response, Started};
/// The helper trait to obtain your session data from a request. /// The helper trait to obtain your session data from a request.
/// ///
@ -40,14 +40,13 @@ pub trait RequestSession {
} }
impl<S> RequestSession for HttpRequest<S> { impl<S> RequestSession for HttpRequest<S> {
fn session(&mut self) -> Session { fn session(&mut self) -> Session {
if let Some(s_impl) = self.extensions().get_mut::<Arc<SessionImplBox>>() { if let Some(s_impl) = self.extensions().get_mut::<Arc<SessionImplBox>>() {
if let Some(s) = Arc::get_mut(s_impl) { if let Some(s) = Arc::get_mut(s_impl) {
return Session(s.0.as_mut()) return Session(s.0.as_mut());
} }
} }
Session(unsafe{&mut DUMMY}) Session(unsafe { &mut DUMMY })
} }
} }
@ -76,7 +75,6 @@ impl<S> RequestSession for HttpRequest<S> {
pub struct Session<'a>(&'a mut SessionImpl); pub struct Session<'a>(&'a mut SessionImpl);
impl<'a> Session<'a> { impl<'a> Session<'a> {
/// Get a `value` from the session. /// Get a `value` from the session.
pub fn get<T: Deserialize<'a>>(&'a self, key: &str) -> Result<Option<T>> { pub fn get<T: Deserialize<'a>>(&'a self, key: &str) -> Result<Option<T>> {
if let Some(s) = self.0.get(key) { if let Some(s) = self.0.get(key) {
@ -136,24 +134,25 @@ impl<S, T: SessionBackend<S>> SessionStorage<T, S> {
} }
impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> { impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> {
let mut req = req.clone(); let mut req = req.clone();
let fut = self.0.from_request(&mut req) let fut = self.0
.then(move |res| { .from_request(&mut req)
match res { .then(move |res| match res {
Ok(sess) => { Ok(sess) => {
req.extensions().insert(Arc::new(SessionImplBox(Box::new(sess)))); req.extensions()
FutOk(None) .insert(Arc::new(SessionImplBox(Box::new(sess))));
}, FutOk(None)
Err(err) => FutErr(err)
} }
Err(err) => FutErr(err),
}); });
Ok(Started::Future(Box::new(fut))) Ok(Started::Future(Box::new(fut)))
} }
fn response(&self, req: &mut HttpRequest<S>, resp: HttpResponse) -> Result<Response> { fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<Response> {
if let Some(s_box) = req.extensions().remove::<Arc<SessionImplBox>>() { if let Some(s_box) = req.extensions().remove::<Arc<SessionImplBox>>() {
s_box.0.write(resp) s_box.0.write(resp)
} else { } else {
@ -165,7 +164,6 @@ impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
/// A simple key-value storage interface that is internally used by `Session`. /// A simple key-value storage interface that is internally used by `Session`.
#[doc(hidden)] #[doc(hidden)]
pub trait SessionImpl: 'static { pub trait SessionImpl: 'static {
fn get(&self, key: &str) -> Option<&str>; fn get(&self, key: &str) -> Option<&str>;
fn set(&mut self, key: &str, value: String); fn set(&mut self, key: &str, value: String);
@ -182,7 +180,7 @@ pub trait SessionImpl: 'static {
#[doc(hidden)] #[doc(hidden)]
pub trait SessionBackend<S>: Sized + 'static { pub trait SessionBackend<S>: Sized + 'static {
type Session: SessionImpl; type Session: SessionImpl;
type ReadFuture: Future<Item=Self::Session, Error=Error>; type ReadFuture: Future<Item = Self::Session, Error = Error>;
/// Parse the session from request and load data from a storage backend. /// Parse the session from request and load data from a storage backend.
fn from_request(&self, request: &mut HttpRequest<S>) -> Self::ReadFuture; fn from_request(&self, request: &mut HttpRequest<S>) -> Self::ReadFuture;
@ -194,8 +192,9 @@ struct DummySessionImpl;
static mut DUMMY: DummySessionImpl = DummySessionImpl; static mut DUMMY: DummySessionImpl = DummySessionImpl;
impl SessionImpl for DummySessionImpl { impl SessionImpl for DummySessionImpl {
fn get(&self, _: &str) -> Option<&str> {
fn get(&self, _: &str) -> Option<&str> { None } None
}
fn set(&mut self, _: &str, _: String) {} fn set(&mut self, _: &str, _: String) {}
fn remove(&mut self, _: &str) {} fn remove(&mut self, _: &str) {}
fn clear(&mut self) {} fn clear(&mut self) {}
@ -215,17 +214,16 @@ pub struct CookieSession {
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum CookieSessionError { pub enum CookieSessionError {
/// Size of the serialized session is greater than 4000 bytes. /// Size of the serialized session is greater than 4000 bytes.
#[fail(display="Size of the serialized session is greater than 4000 bytes.")] #[fail(display = "Size of the serialized session is greater than 4000 bytes.")]
Overflow, Overflow,
/// Fail to serialize session. /// Fail to serialize session.
#[fail(display="Fail to serialize session")] #[fail(display = "Fail to serialize session")]
Serialize(JsonError), Serialize(JsonError),
} }
impl ResponseError for CookieSessionError {} impl ResponseError for CookieSessionError {}
impl SessionImpl for CookieSession { impl SessionImpl for CookieSession {
fn get(&self, key: &str) -> Option<&str> { fn get(&self, key: &str) -> Option<&str> {
if let Some(s) = self.state.get(key) { if let Some(s) = self.state.get(key) {
Some(s) Some(s)
@ -259,7 +257,7 @@ impl SessionImpl for CookieSession {
enum CookieSecurity { enum CookieSecurity {
Signed, Signed,
Private Private,
} }
struct CookieSessionInner { struct CookieSessionInner {
@ -273,7 +271,6 @@ struct CookieSessionInner {
} }
impl CookieSessionInner { impl CookieSessionInner {
fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner { fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner {
CookieSessionInner { CookieSessionInner {
security, security,
@ -286,11 +283,13 @@ impl CookieSessionInner {
} }
} }
fn set_cookie(&self, resp: &mut HttpResponse, state: &HashMap<String, String>) -> Result<()> { fn set_cookie(
let value = serde_json::to_string(&state) &self, resp: &mut HttpResponse, state: &HashMap<String, String>
.map_err(CookieSessionError::Serialize)?; ) -> Result<()> {
let value =
serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?;
if value.len() > 4064 { if value.len() > 4064 {
return Err(CookieSessionError::Overflow.into()) return Err(CookieSessionError::Overflow.into());
} }
let mut cookie = Cookie::new(self.name.clone(), value); let mut cookie = Cookie::new(self.name.clone(), value);
@ -330,7 +329,9 @@ impl CookieSessionInner {
let cookie_opt = match self.security { let cookie_opt = match self.security {
CookieSecurity::Signed => jar.signed(&self.key).get(&self.name), CookieSecurity::Signed => jar.signed(&self.key).get(&self.name),
CookieSecurity::Private => jar.private(&self.key).get(&self.name), CookieSecurity::Private => {
jar.private(&self.key).get(&self.name)
}
}; };
if let Some(cookie) = cookie_opt { if let Some(cookie) = cookie_opt {
if let Ok(val) = serde_json::from_str(cookie.value()) { if let Ok(val) = serde_json::from_str(cookie.value()) {
@ -347,20 +348,24 @@ impl CookieSessionInner {
/// Use cookies for session storage. /// Use cookies for session storage.
/// ///
/// `CookieSessionBackend` creates sessions which are limited to storing /// `CookieSessionBackend` creates sessions which are limited to storing
/// fewer than 4000 bytes of data (as the payload must fit into a single cookie). /// fewer than 4000 bytes of data (as the payload must fit into a single
/// An Internal Server Error is generated if the session contains more than 4000 bytes. /// cookie). An Internal Server Error is generated if the session contains more
/// than 4000 bytes.
/// ///
/// A cookie may have a security policy of *signed* or *private*. Each has a respective `CookieSessionBackend` constructor. /// A cookie may have a security policy of *signed* or *private*. Each has a
/// respective `CookieSessionBackend` constructor.
/// ///
/// A *signed* cookie is stored on the client as plaintext alongside /// A *signed* cookie is stored on the client as plaintext alongside
/// a signature such that the cookie may be viewed but not modified by the client. /// a signature such that the cookie may be viewed but not modified by the
/// client.
/// ///
/// A *private* cookie is stored on the client as encrypted text /// A *private* cookie is stored on the client as encrypted text
/// such that it may neither be viewed nor modified by the client. /// such that it may neither be viewed nor modified by the client.
/// ///
/// The constructors take a key as an argument. /// The constructors take a key as an argument.
/// This is the private key for cookie session - when this value is changed, all session data is lost. /// This is the private key for cookie session - when this value is changed,
/// The constructors will panic if the key is less than 32 bytes in length. /// all session data is lost. The constructors will panic if the key is less
/// than 32 bytes in length.
/// ///
/// ///
/// # Example /// # Example
@ -380,21 +385,24 @@ impl CookieSessionInner {
pub struct CookieSessionBackend(Rc<CookieSessionInner>); pub struct CookieSessionBackend(Rc<CookieSessionInner>);
impl CookieSessionBackend { impl CookieSessionBackend {
/// Construct new *signed* `CookieSessionBackend` instance. /// Construct new *signed* `CookieSessionBackend` instance.
/// ///
/// Panics if key length is less than 32 bytes. /// Panics if key length is less than 32 bytes.
pub fn signed(key: &[u8]) -> CookieSessionBackend { pub fn signed(key: &[u8]) -> CookieSessionBackend {
CookieSessionBackend( CookieSessionBackend(Rc::new(CookieSessionInner::new(
Rc::new(CookieSessionInner::new(key, CookieSecurity::Signed))) key,
CookieSecurity::Signed,
)))
} }
/// Construct new *private* `CookieSessionBackend` instance. /// Construct new *private* `CookieSessionBackend` instance.
/// ///
/// Panics if key length is less than 32 bytes. /// Panics if key length is less than 32 bytes.
pub fn private(key: &[u8]) -> CookieSessionBackend { pub fn private(key: &[u8]) -> CookieSessionBackend {
CookieSessionBackend( CookieSessionBackend(Rc::new(CookieSessionInner::new(
Rc::new(CookieSessionInner::new(key, CookieSecurity::Private))) key,
CookieSecurity::Private,
)))
} }
/// Sets the `path` field in the session cookie being built. /// Sets the `path` field in the session cookie being built.
@ -432,17 +440,15 @@ impl CookieSessionBackend {
} }
impl<S> SessionBackend<S> for CookieSessionBackend { impl<S> SessionBackend<S> for CookieSessionBackend {
type Session = CookieSession; type Session = CookieSession;
type ReadFuture = FutureResult<CookieSession, Error>; type ReadFuture = FutureResult<CookieSession, Error>;
fn from_request(&self, req: &mut HttpRequest<S>) -> Self::ReadFuture { fn from_request(&self, req: &mut HttpRequest<S>) -> Self::ReadFuture {
let state = self.0.load(req); let state = self.0.load(req);
FutOk( FutOk(CookieSession {
CookieSession { changed: false,
changed: false, inner: Rc::clone(&self.0),
inner: Rc::clone(&self.0), state,
state, })
})
} }
} }

View File

@ -1,18 +1,18 @@
//! Multipart requests support //! Multipart requests support
use std::{cmp, fmt};
use std::rc::Rc;
use std::cell::RefCell; use std::cell::RefCell;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc;
use std::{cmp, fmt};
use mime;
use httparse;
use bytes::Bytes; use bytes::Bytes;
use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream};
use http::HttpTryFrom; use http::HttpTryFrom;
use http::header::{self, HeaderMap, HeaderName, HeaderValue}; use http::header::{self, HeaderMap, HeaderName, HeaderValue};
use futures::{Async, Stream, Poll}; use httparse;
use futures::task::{Task, current as current_task}; use mime;
use error::{ParseError, PayloadError, MultipartError}; use error::{MultipartError, ParseError, PayloadError};
use payload::PayloadHelper; use payload::PayloadHelper;
const MAX_HEADERS: usize = 32; const MAX_HEADERS: usize = 32;
@ -85,33 +85,36 @@ impl Multipart<()> {
} }
} }
impl<S> Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Multipart<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create multipart instance for boundary. /// Create multipart instance for boundary.
pub fn new(boundary: Result<String, MultipartError>, stream: S) -> Multipart<S> { pub fn new(boundary: Result<String, MultipartError>, stream: S) -> Multipart<S> {
match boundary { match boundary {
Ok(boundary) => Multipart { Ok(boundary) => Multipart {
error: None, error: None,
safety: Safety::new(), safety: Safety::new(),
inner: Some(Rc::new(RefCell::new( inner: Some(Rc::new(RefCell::new(InnerMultipart {
InnerMultipart { boundary,
boundary, payload: PayloadRef::new(PayloadHelper::new(stream)),
payload: PayloadRef::new(PayloadHelper::new(stream)), state: InnerState::FirstBoundary,
state: InnerState::FirstBoundary, item: InnerMultipartItem::None,
item: InnerMultipartItem::None, }))),
}))) },
Err(err) => Multipart {
error: Some(err),
safety: Safety::new(),
inner: None,
}, },
Err(err) =>
Multipart {
error: Some(err),
safety: Safety::new(),
inner: None,
}
} }
} }
} }
impl<S> Stream for Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Stream for Multipart<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = MultipartItem<S>; type Item = MultipartItem<S>;
type Error = MultipartError; type Error = MultipartError;
@ -119,17 +122,22 @@ impl<S> Stream for Multipart<S> where S: Stream<Item=Bytes, Error=PayloadError>
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
Err(err) Err(err)
} else if self.safety.current() { } else if self.safety.current() {
self.inner.as_mut().unwrap().borrow_mut().poll(&self.safety) self.inner
.as_mut()
.unwrap()
.borrow_mut()
.poll(&self.safety)
} else { } else {
Ok(Async::NotReady) Ok(Async::NotReady)
} }
} }
} }
impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> InnerMultipart<S>
where
fn read_headers(payload: &mut PayloadHelper<S>) -> Poll<HeaderMap, MultipartError> S: Stream<Item = Bytes, Error = PayloadError>,
{ {
fn read_headers(payload: &mut PayloadHelper<S>) -> Poll<HeaderMap, MultipartError> {
match payload.read_until(b"\r\n\r\n")? { match payload.read_until(b"\r\n\r\n")? {
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete), Async::Ready(None) => Err(MultipartError::Incomplete),
@ -144,10 +152,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
if let Ok(value) = HeaderValue::try_from(h.value) { if let Ok(value) = HeaderValue::try_from(h.value) {
headers.append(name, value); headers.append(name, value);
} else { } else {
return Err(ParseError::Header.into()) return Err(ParseError::Header.into());
} }
} else { } else {
return Err(ParseError::Header.into()) return Err(ParseError::Header.into());
} }
} }
Ok(Async::Ready(headers)) Ok(Async::Ready(headers))
@ -159,23 +167,21 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
} }
fn read_boundary(payload: &mut PayloadHelper<S>, boundary: &str) fn read_boundary(
-> Poll<bool, MultipartError> payload: &mut PayloadHelper<S>, boundary: &str
{ ) -> Poll<bool, MultipartError> {
// TODO: need to read epilogue // TODO: need to read epilogue
match payload.readline()? { match payload.readline()? {
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete), Async::Ready(None) => Err(MultipartError::Incomplete),
Async::Ready(Some(chunk)) => { Async::Ready(Some(chunk)) => {
if chunk.len() == boundary.len() + 4 && if chunk.len() == boundary.len() + 4 && &chunk[..2] == b"--"
&chunk[..2] == b"--" && && &chunk[2..boundary.len() + 2] == boundary.as_bytes()
&chunk[2..boundary.len()+2] == boundary.as_bytes()
{ {
Ok(Async::Ready(false)) Ok(Async::Ready(false))
} else if chunk.len() == boundary.len() + 6 && } else if chunk.len() == boundary.len() + 6 && &chunk[..2] == b"--"
&chunk[..2] == b"--" && && &chunk[2..boundary.len() + 2] == boundary.as_bytes()
&chunk[2..boundary.len()+2] == boundary.as_bytes() && && &chunk[boundary.len() + 2..boundary.len() + 4] == b"--"
&chunk[boundary.len()+2..boundary.len()+4] == b"--"
{ {
Ok(Async::Ready(true)) Ok(Async::Ready(true))
} else { } else {
@ -185,9 +191,9 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
} }
fn skip_until_boundary(payload: &mut PayloadHelper<S>, boundary: &str) fn skip_until_boundary(
-> Poll<bool, MultipartError> payload: &mut PayloadHelper<S>, boundary: &str
{ ) -> Poll<bool, MultipartError> {
let mut eof = false; let mut eof = false;
loop { loop {
match payload.readline()? { match payload.readline()? {
@ -197,22 +203,25 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
//% (self._boundary)) //% (self._boundary))
} }
if chunk.len() < boundary.len() { if chunk.len() < boundary.len() {
continue continue;
} }
if &chunk[..2] == b"--" && &chunk[2..chunk.len()-2] == boundary.as_bytes() { if &chunk[..2] == b"--"
&& &chunk[2..chunk.len() - 2] == boundary.as_bytes()
{
break; break;
} else { } else {
if chunk.len() < boundary.len() + 2{ if chunk.len() < boundary.len() + 2 {
continue continue;
} }
let b: &[u8] = boundary.as_ref(); let b: &[u8] = boundary.as_ref();
if &chunk[..boundary.len()] == b && if &chunk[..boundary.len()] == b
&chunk[boundary.len()..boundary.len()+2] == b"--" { && &chunk[boundary.len()..boundary.len() + 2] == b"--"
eof = true; {
break; eof = true;
} break;
}
} }
}, }
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
Async::Ready(None) => return Err(MultipartError::Incomplete), Async::Ready(None) => return Err(MultipartError::Incomplete),
} }
@ -220,7 +229,9 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Ok(Async::Ready(eof)) Ok(Async::Ready(eof))
} }
fn poll(&mut self, safety: &Safety) -> Poll<Option<MultipartItem<S>>, MultipartError> { fn poll(
&mut self, safety: &Safety
) -> Poll<Option<MultipartItem<S>>, MultipartError> {
if self.state == InnerState::Eof { if self.state == InnerState::Eof {
Ok(Async::Ready(None)) Ok(Async::Ready(None))
} else { } else {
@ -236,14 +247,14 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Async::Ready(Some(_)) => continue, Async::Ready(Some(_)) => continue,
Async::Ready(None) => true, Async::Ready(None) => true,
} }
}, }
InnerMultipartItem::Multipart(ref mut multipart) => { InnerMultipartItem::Multipart(ref mut multipart) => {
match multipart.borrow_mut().poll(safety)? { match multipart.borrow_mut().poll(safety)? {
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(_)) => continue, Async::Ready(Some(_)) => continue,
Async::Ready(None) => true, Async::Ready(None) => true,
} }
}, }
_ => false, _ => false,
}; };
if stop { if stop {
@ -259,7 +270,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
match self.state { match self.state {
// read until first boundary // read until first boundary
InnerState::FirstBoundary => { InnerState::FirstBoundary => {
match InnerMultipart::skip_until_boundary(payload, &self.boundary)? { match InnerMultipart::skip_until_boundary(
payload,
&self.boundary,
)? {
Async::Ready(eof) => { Async::Ready(eof) => {
if eof { if eof {
self.state = InnerState::Eof; self.state = InnerState::Eof;
@ -267,10 +281,10 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} else { } else {
self.state = InnerState::Headers; self.state = InnerState::Headers;
} }
}, }
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
} }
}, }
// read boundary // read boundary
InnerState::Boundary => { InnerState::Boundary => {
match InnerMultipart::read_boundary(payload, &self.boundary)? { match InnerMultipart::read_boundary(payload, &self.boundary)? {
@ -290,18 +304,19 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// read field headers for next field // read field headers for next field
if self.state == InnerState::Headers { if self.state == InnerState::Headers {
if let Async::Ready(headers) = InnerMultipart::read_headers(payload)? { if let Async::Ready(headers) = InnerMultipart::read_headers(payload)?
{
self.state = InnerState::Boundary; self.state = InnerState::Boundary;
headers headers
} else { } else {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} else { } else {
unreachable!() unreachable!()
} }
} else { } else {
debug!("NotReady: field is in flight"); debug!("NotReady: field is in flight");
return Ok(Async::NotReady) return Ok(Async::NotReady);
}; };
// content type // content type
@ -319,32 +334,37 @@ impl<S> InnerMultipart<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// nested multipart stream // nested multipart stream
if mt.type_() == mime::MULTIPART { if mt.type_() == mime::MULTIPART {
let inner = if let Some(boundary) = mt.get_param(mime::BOUNDARY) { let inner = if let Some(boundary) = mt.get_param(mime::BOUNDARY) {
Rc::new(RefCell::new( Rc::new(RefCell::new(InnerMultipart {
InnerMultipart { payload: self.payload.clone(),
payload: self.payload.clone(), boundary: boundary.as_str().to_owned(),
boundary: boundary.as_str().to_owned(), state: InnerState::FirstBoundary,
state: InnerState::FirstBoundary, item: InnerMultipartItem::None,
item: InnerMultipartItem::None, }))
}))
} else { } else {
return Err(MultipartError::Boundary) return Err(MultipartError::Boundary);
}; };
self.item = InnerMultipartItem::Multipart(Rc::clone(&inner)); self.item = InnerMultipartItem::Multipart(Rc::clone(&inner));
Ok(Async::Ready(Some( Ok(Async::Ready(Some(MultipartItem::Nested(Multipart {
MultipartItem::Nested( safety: safety.clone(),
Multipart{safety: safety.clone(), error: None,
error: None, inner: Some(inner),
inner: Some(inner)})))) }))))
} else { } else {
let field = Rc::new(RefCell::new(InnerField::new( let field = Rc::new(RefCell::new(InnerField::new(
self.payload.clone(), self.boundary.clone(), &headers)?)); self.payload.clone(),
self.boundary.clone(),
&headers,
)?));
self.item = InnerMultipartItem::Field(Rc::clone(&field)); self.item = InnerMultipartItem::Field(Rc::clone(&field));
Ok(Async::Ready(Some( Ok(Async::Ready(Some(MultipartItem::Field(Field::new(
MultipartItem::Field( safety.clone(),
Field::new(safety.clone(), headers, mt, field))))) headers,
mt,
field,
)))))
} }
} }
} }
@ -365,11 +385,20 @@ pub struct Field<S> {
safety: Safety, safety: Safety,
} }
impl<S> Field<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Field<S>
where
fn new(safety: Safety, headers: HeaderMap, S: Stream<Item = Bytes, Error = PayloadError>,
ct: mime::Mime, inner: Rc<RefCell<InnerField<S>>>) -> Self { {
Field {ct, headers, inner, safety} fn new(
safety: Safety, headers: HeaderMap, ct: mime::Mime,
inner: Rc<RefCell<InnerField<S>>>,
) -> Self {
Field {
ct,
headers,
inner,
safety,
}
} }
pub fn headers(&self) -> &HeaderMap { pub fn headers(&self) -> &HeaderMap {
@ -381,7 +410,10 @@ impl<S> Field<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
} }
impl<S> Stream for Field<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Stream for Field<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = Bytes; type Item = Bytes;
type Error = MultipartError; type Error = MultipartError;
@ -413,20 +445,22 @@ struct InnerField<S> {
length: Option<u64>, length: Option<u64>,
} }
impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> InnerField<S>
where
fn new(payload: PayloadRef<S>, boundary: String, headers: &HeaderMap) S: Stream<Item = Bytes, Error = PayloadError>,
-> Result<InnerField<S>, PayloadError> {
{ fn new(
payload: PayloadRef<S>, boundary: String, headers: &HeaderMap
) -> Result<InnerField<S>, PayloadError> {
let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) { let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) {
if let Ok(s) = len.to_str() { if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() { if let Ok(len) = s.parse::<u64>() {
Some(len) Some(len)
} else { } else {
return Err(PayloadError::Incomplete) return Err(PayloadError::Incomplete);
} }
} else { } else {
return Err(PayloadError::Incomplete) return Err(PayloadError::Incomplete);
} }
} else { } else {
None None
@ -436,14 +470,15 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
boundary, boundary,
payload: Some(payload), payload: Some(payload),
eof: false, eof: false,
length: len }) length: len,
})
} }
/// Reads body part content chunk of the specified size. /// Reads body part content chunk of the specified size.
/// The body part must has `Content-Length` header with proper value. /// The body part must has `Content-Length` header with proper value.
fn read_len(payload: &mut PayloadHelper<S>, size: &mut u64) fn read_len(
-> Poll<Option<Bytes>, MultipartError> payload: &mut PayloadHelper<S>, size: &mut u64
{ ) -> Poll<Option<Bytes>, MultipartError> {
if *size == 0 { if *size == 0 {
Ok(Async::Ready(None)) Ok(Async::Ready(None))
} else { } else {
@ -458,17 +493,17 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
payload.unread_data(chunk); payload.unread_data(chunk);
} }
Ok(Async::Ready(Some(ch))) Ok(Async::Ready(Some(ch)))
}, }
Err(err) => Err(err.into()) Err(err) => Err(err.into()),
} }
} }
} }
/// Reads content chunk of body part with unknown length. /// Reads content chunk of body part with unknown length.
/// The `Content-Length` header for body part is not necessary. /// The `Content-Length` header for body part is not necessary.
fn read_stream(payload: &mut PayloadHelper<S>, boundary: &str) fn read_stream(
-> Poll<Option<Bytes>, MultipartError> payload: &mut PayloadHelper<S>, boundary: &str
{ ) -> Poll<Option<Bytes>, MultipartError> {
match payload.read_until(b"\r")? { match payload.read_until(b"\r")? {
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete), Async::Ready(None) => Err(MultipartError::Incomplete),
@ -479,8 +514,8 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
Async::Ready(None) => Err(MultipartError::Incomplete), Async::Ready(None) => Err(MultipartError::Incomplete),
Async::Ready(Some(chunk)) => { Async::Ready(Some(chunk)) => {
if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--" && if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--"
&chunk[4..] == boundary.as_bytes() && &chunk[4..] == boundary.as_bytes()
{ {
payload.unread_data(chunk); payload.unread_data(chunk);
Ok(Async::Ready(None)) Ok(Async::Ready(None))
@ -501,7 +536,7 @@ impl<S> InnerField<S> where S: Stream<Item=Bytes, Error=PayloadError> {
fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> { fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> {
if self.payload.is_none() { if self.payload.is_none() {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None));
} }
let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) { let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) {
@ -543,7 +578,10 @@ struct PayloadRef<S> {
payload: Rc<PayloadHelper<S>>, payload: Rc<PayloadHelper<S>>,
} }
impl<S> PayloadRef<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> PayloadRef<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
fn new(payload: PayloadHelper<S>) -> PayloadRef<S> { fn new(payload: PayloadHelper<S>) -> PayloadRef<S> {
PayloadRef { PayloadRef {
payload: Rc::new(payload), payload: Rc::new(payload),
@ -551,11 +589,12 @@ impl<S> PayloadRef<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadHelper<S>> fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadHelper<S>>
where 'a: 'b where
'a: 'b,
{ {
if s.current() { if s.current() {
let payload: &mut PayloadHelper<S> = unsafe { let payload: &mut PayloadHelper<S> =
&mut *(self.payload.as_ref() as *const _ as *mut _)}; unsafe { &mut *(self.payload.as_ref() as *const _ as *mut _) };
Some(payload) Some(payload)
} else { } else {
None None
@ -571,8 +610,9 @@ impl<S> Clone for PayloadRef<S> {
} }
} }
/// Counter. It tracks of number of clones of payloads and give access to payload only /// Counter. It tracks of number of clones of payloads and give access to
/// to top most task panics if Safety get destroyed and it not top most task. /// payload only to top most task panics if Safety get destroyed and it not top
/// most task.
#[derive(Debug)] #[derive(Debug)]
struct Safety { struct Safety {
task: Option<Task>, task: Option<Task>,
@ -593,7 +633,6 @@ impl Safety {
fn current(&self) -> bool { fn current(&self) -> bool {
Rc::strong_count(&self.payload) == self.level Rc::strong_count(&self.payload) == self.level
} }
} }
impl Clone for Safety { impl Clone for Safety {
@ -624,8 +663,8 @@ mod tests {
use super::*; use super::*;
use bytes::Bytes; use bytes::Bytes;
use futures::future::{lazy, result}; use futures::future::{lazy, result};
use tokio_core::reactor::Core;
use payload::{Payload, PayloadWriter}; use payload::{Payload, PayloadWriter};
use tokio_core::reactor::Core;
#[test] #[test]
fn test_boundary() { fn test_boundary() {
@ -636,8 +675,10 @@ mod tests {
} }
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, headers.insert(
header::HeaderValue::from_static("test")); header::CONTENT_TYPE,
header::HeaderValue::from_static("test"),
);
match Multipart::boundary(&headers) { match Multipart::boundary(&headers) {
Err(MultipartError::ParseContentType) => (), Err(MultipartError::ParseContentType) => (),
@ -647,7 +688,8 @@ mod tests {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("multipart/mixed")); header::HeaderValue::from_static("multipart/mixed"),
);
match Multipart::boundary(&headers) { match Multipart::boundary(&headers) {
Err(MultipartError::Boundary) => (), Err(MultipartError::Boundary) => (),
_ => unreachable!("should not happen"), _ => unreachable!("should not happen"),
@ -657,18 +699,24 @@ mod tests {
headers.insert( headers.insert(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static( header::HeaderValue::from_static(
"multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"")); "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"",
),
);
assert_eq!(Multipart::boundary(&headers).unwrap(), assert_eq!(
"5c02368e880e436dab70ed54e1c58209"); Multipart::boundary(&headers).unwrap(),
"5c02368e880e436dab70ed54e1c58209"
);
} }
#[test] #[test]
fn test_multipart() { fn test_multipart() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
.run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let bytes = Bytes::from( let bytes = Bytes::from(
"testasdadsad\r\n\ "testasdadsad\r\n\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
@ -677,63 +725,64 @@ mod tests {
Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
data\r\n\ data\r\n\
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n"); --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n");
sender.feed_data(bytes); sender.feed_data(bytes);
let mut multipart = Multipart::new( let mut multipart = Multipart::new(
Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()), payload); Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()),
match multipart.poll() { payload,
Ok(Async::Ready(Some(item))) => { );
match item { match multipart.poll() {
Ok(Async::Ready(Some(item))) => match item {
MultipartItem::Field(mut field) => { MultipartItem::Field(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() { match field.poll() {
Ok(Async::Ready(Some(chunk))) => Ok(Async::Ready(Some(chunk))) => {
assert_eq!(chunk, "test"), assert_eq!(chunk, "test")
_ => unreachable!() }
_ => unreachable!(),
} }
match field.poll() { match field.poll() {
Ok(Async::Ready(None)) => (), Ok(Async::Ready(None)) => (),
_ => unreachable!() _ => unreachable!(),
} }
}, }
_ => unreachable!() _ => unreachable!(),
} },
_ => unreachable!(),
} }
_ => unreachable!()
}
match multipart.poll() { match multipart.poll() {
Ok(Async::Ready(Some(item))) => { Ok(Async::Ready(Some(item))) => match item {
match item {
MultipartItem::Field(mut field) => { MultipartItem::Field(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() { match field.poll() {
Ok(Async::Ready(Some(chunk))) => Ok(Async::Ready(Some(chunk))) => {
assert_eq!(chunk, "data"), assert_eq!(chunk, "data")
_ => unreachable!() }
_ => unreachable!(),
} }
match field.poll() { match field.poll() {
Ok(Async::Ready(None)) => (), Ok(Async::Ready(None)) => (),
_ => unreachable!() _ => unreachable!(),
} }
}, }
_ => unreachable!() _ => unreachable!(),
} },
_ => unreachable!(),
} }
_ => unreachable!()
}
match multipart.poll() { match multipart.poll() {
Ok(Async::Ready(None)) => (), Ok(Async::Ready(None)) => (),
_ => unreachable!() _ => unreachable!(),
} }
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
} }

View File

@ -1,16 +1,16 @@
use std;
use std::ops::Index;
use std::path::PathBuf;
use std::str::FromStr;
use std::slice::Iter;
use std::borrow::Cow;
use http::StatusCode; use http::StatusCode;
use smallvec::SmallVec; use smallvec::SmallVec;
use std;
use std::borrow::Cow;
use std::ops::Index;
use std::path::PathBuf;
use std::slice::Iter;
use std::str::FromStr;
use error::{ResponseError, UriSegmentError, InternalError}; use error::{InternalError, ResponseError, UriSegmentError};
/// A trait to abstract the idea of creating a new instance of a type from a
/// A trait to abstract the idea of creating a new instance of a type from a path parameter. /// path parameter.
pub trait FromParam: Sized { pub trait FromParam: Sized {
/// The associated error which can be returned from parsing. /// The associated error which can be returned from parsing.
type Err: ResponseError; type Err: ResponseError;
@ -26,7 +26,6 @@ pub trait FromParam: Sized {
pub struct Params<'a>(SmallVec<[(Cow<'a, str>, Cow<'a, str>); 3]>); pub struct Params<'a>(SmallVec<[(Cow<'a, str>, Cow<'a, str>); 3]>);
impl<'a> Params<'a> { impl<'a> Params<'a> {
pub(crate) fn new() -> Params<'a> { pub(crate) fn new() -> Params<'a> {
Params(SmallVec::new()) Params(SmallVec::new())
} }
@ -36,7 +35,9 @@ impl<'a> Params<'a> {
} }
pub(crate) fn add<N, V>(&mut self, name: N, value: V) pub(crate) fn add<N, V>(&mut self, name: N, value: V)
where N: Into<Cow<'a, str>>, V: Into<Cow<'a, str>>, where
N: Into<Cow<'a, str>>,
V: Into<Cow<'a, str>>,
{ {
self.0.push((name.into(), value.into())); self.0.push((name.into(), value.into()));
} }
@ -55,7 +56,7 @@ impl<'a> Params<'a> {
pub fn get(&'a self, key: &str) -> Option<&'a str> { pub fn get(&'a self, key: &str) -> Option<&'a str> {
for item in self.0.iter() { for item in self.0.iter() {
if key == item.0 { if key == item.0 {
return Some(item.1.as_ref()) return Some(item.1.as_ref());
} }
} }
None None
@ -63,7 +64,8 @@ impl<'a> Params<'a> {
/// Get matched `FromParam` compatible parameter by name. /// Get matched `FromParam` compatible parameter by name.
/// ///
/// If keyed parameter is not available empty string is used as default value. /// If keyed parameter is not available empty string is used as default
/// value.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -74,8 +76,7 @@ impl<'a> Params<'a> {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
pub fn query<T: FromParam>(&'a self, key: &str) -> Result<T, <T as FromParam>::Err> pub fn query<T: FromParam>(&'a self, key: &str) -> Result<T, <T as FromParam>::Err> {
{
if let Some(s) = self.get(key) { if let Some(s) = self.get(key) {
T::from_param(s) T::from_param(s)
} else { } else {
@ -93,7 +94,8 @@ impl<'a, 'b, 'c: 'a> Index<&'b str> for &'c Params<'a> {
type Output = str; type Output = str;
fn index(&self, name: &'b str) -> &str { fn index(&self, name: &'b str) -> &str {
self.get(name).expect("Value for parameter is not available") self.get(name)
.expect("Value for parameter is not available")
} }
} }
@ -118,9 +120,9 @@ impl<'a, 'c: 'a> Index<usize> for &'c Params<'a> {
/// * On Windows, decoded segment contains any of: '\' /// * On Windows, decoded segment contains any of: '\'
/// * Percent-encoding results in invalid UTF8. /// * Percent-encoding results in invalid UTF8.
/// ///
/// As a result of these conditions, a `PathBuf` parsed from request path parameter is /// As a result of these conditions, a `PathBuf` parsed from request path
/// safe to interpolate within, or use as a suffix of, a path without additional /// parameter is safe to interpolate within, or use as a suffix of, a path
/// checks. /// without additional checks.
impl FromParam for PathBuf { impl FromParam for PathBuf {
type Err = UriSegmentError; type Err = UriSegmentError;
@ -130,19 +132,19 @@ impl FromParam for PathBuf {
if segment == ".." { if segment == ".." {
buf.pop(); buf.pop();
} else if segment.starts_with('.') { } else if segment.starts_with('.') {
return Err(UriSegmentError::BadStart('.')) return Err(UriSegmentError::BadStart('.'));
} else if segment.starts_with('*') { } else if segment.starts_with('*') {
return Err(UriSegmentError::BadStart('*')) return Err(UriSegmentError::BadStart('*'));
} else if segment.ends_with(':') { } else if segment.ends_with(':') {
return Err(UriSegmentError::BadEnd(':')) return Err(UriSegmentError::BadEnd(':'));
} else if segment.ends_with('>') { } else if segment.ends_with('>') {
return Err(UriSegmentError::BadEnd('>')) return Err(UriSegmentError::BadEnd('>'));
} else if segment.ends_with('<') { } else if segment.ends_with('<') {
return Err(UriSegmentError::BadEnd('<')) return Err(UriSegmentError::BadEnd('<'));
} else if segment.is_empty() { } else if segment.is_empty() {
continue continue;
} else if cfg!(windows) && segment.contains('\\') { } else if cfg!(windows) && segment.contains('\\') {
return Err(UriSegmentError::BadChar('\\')) return Err(UriSegmentError::BadChar('\\'));
} else { } else {
buf.push(segment) buf.push(segment)
} }
@ -162,7 +164,7 @@ macro_rules! FROM_STR {
.map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST)) .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST))
} }
} }
} };
} }
FROM_STR!(u8); FROM_STR!(u8);
@ -192,14 +194,33 @@ mod tests {
#[test] #[test]
fn test_path_buf() { fn test_path_buf() {
assert_eq!(PathBuf::from_param("/test/.tt"), Err(UriSegmentError::BadStart('.'))); assert_eq!(
assert_eq!(PathBuf::from_param("/test/*tt"), Err(UriSegmentError::BadStart('*'))); PathBuf::from_param("/test/.tt"),
assert_eq!(PathBuf::from_param("/test/tt:"), Err(UriSegmentError::BadEnd(':'))); Err(UriSegmentError::BadStart('.'))
assert_eq!(PathBuf::from_param("/test/tt<"), Err(UriSegmentError::BadEnd('<'))); );
assert_eq!(PathBuf::from_param("/test/tt>"), Err(UriSegmentError::BadEnd('>'))); assert_eq!(
assert_eq!(PathBuf::from_param("/seg1/seg2/"), PathBuf::from_param("/test/*tt"),
Ok(PathBuf::from_iter(vec!["seg1", "seg2"]))); Err(UriSegmentError::BadStart('*'))
assert_eq!(PathBuf::from_param("/seg1/../seg2/"), );
Ok(PathBuf::from_iter(vec!["seg2"]))); assert_eq!(
PathBuf::from_param("/test/tt:"),
Err(UriSegmentError::BadEnd(':'))
);
assert_eq!(
PathBuf::from_param("/test/tt<"),
Err(UriSegmentError::BadEnd('<'))
);
assert_eq!(
PathBuf::from_param("/test/tt>"),
Err(UriSegmentError::BadEnd('>'))
);
assert_eq!(
PathBuf::from_param("/seg1/seg2/"),
Ok(PathBuf::from_iter(vec!["seg1", "seg2"]))
);
assert_eq!(
PathBuf::from_param("/seg1/../seg2/"),
Ok(PathBuf::from_iter(vec!["seg2"]))
);
} }
} }

View File

@ -1,18 +1,17 @@
//! Payload stream //! Payload stream
use std::cmp;
use std::rc::{Rc, Weak};
use std::cell::RefCell;
use std::collections::VecDeque;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use futures::task::{Task, current as current_task}; use std::cell::RefCell;
use std::cmp;
use std::collections::VecDeque;
use std::rc::{Rc, Weak};
use error::PayloadError; use error::PayloadError;
/// max buffer size 32k /// max buffer size 32k
pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; pub(crate) const MAX_BUFFER_SIZE: usize = 32_768;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub(crate) enum PayloadStatus { pub(crate) enum PayloadStatus {
Read, Read,
@ -22,9 +21,9 @@ pub(crate) enum PayloadStatus {
/// Buffered stream of bytes chunks /// Buffered stream of bytes chunks
/// ///
/// Payload stores chunks in a vector. First chunk can be received with `.readany()` method. /// Payload stores chunks in a vector. First chunk can be received with
/// Payload stream is not thread safe. Payload does not notify current task when /// `.readany()` method. Payload stream is not thread safe. Payload does not
/// new data is available. /// notify current task when new data is available.
/// ///
/// Payload stream can be used as `HttpResponse` body stream. /// Payload stream can be used as `HttpResponse` body stream.
#[derive(Debug)] #[derive(Debug)]
@ -33,10 +32,10 @@ pub struct Payload {
} }
impl Payload { impl Payload {
/// Create payload stream. /// Create payload stream.
/// ///
/// This method construct two objects responsible for bytes stream generation. /// This method construct two objects responsible for bytes stream
/// generation.
/// ///
/// * `PayloadSender` - *Sender* side of the stream /// * `PayloadSender` - *Sender* side of the stream
/// ///
@ -44,13 +43,20 @@ impl Payload {
pub fn new(eof: bool) -> (PayloadSender, Payload) { pub fn new(eof: bool) -> (PayloadSender, Payload) {
let shared = Rc::new(RefCell::new(Inner::new(eof))); let shared = Rc::new(RefCell::new(Inner::new(eof)));
(PayloadSender{inner: Rc::downgrade(&shared)}, Payload{inner: shared}) (
PayloadSender {
inner: Rc::downgrade(&shared),
},
Payload { inner: shared },
)
} }
/// Create empty payload /// Create empty payload
#[doc(hidden)] #[doc(hidden)]
pub fn empty() -> Payload { pub fn empty() -> Payload {
Payload{inner: Rc::new(RefCell::new(Inner::new(true)))} Payload {
inner: Rc::new(RefCell::new(Inner::new(true))),
}
} }
/// Indicates EOF of payload /// Indicates EOF of payload
@ -103,13 +109,14 @@ impl Stream for Payload {
impl Clone for Payload { impl Clone for Payload {
fn clone(&self) -> Payload { fn clone(&self) -> Payload {
Payload{inner: Rc::clone(&self.inner)} Payload {
inner: Rc::clone(&self.inner),
}
} }
} }
/// Payload writer interface. /// Payload writer interface.
pub(crate) trait PayloadWriter { pub(crate) trait PayloadWriter {
/// Set stream error. /// Set stream error.
fn set_error(&mut self, err: PayloadError); fn set_error(&mut self, err: PayloadError);
@ -129,7 +136,6 @@ pub struct PayloadSender {
} }
impl PayloadWriter for PayloadSender { impl PayloadWriter for PayloadSender {
#[inline] #[inline]
fn set_error(&mut self, err: PayloadError) { fn set_error(&mut self, err: PayloadError) {
if let Some(shared) = self.inner.upgrade() { if let Some(shared) = self.inner.upgrade() {
@ -186,7 +192,6 @@ struct Inner {
} }
impl Inner { impl Inner {
fn new(eof: bool) -> Self { fn new(eof: bool) -> Self {
Inner { Inner {
eof, eof,
@ -292,8 +297,10 @@ pub struct PayloadHelper<S> {
stream: S, stream: S,
} }
impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> PayloadHelper<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
pub fn new(stream: S) -> Self { pub fn new(stream: S) -> Self {
PayloadHelper { PayloadHelper {
len: 0, len: 0,
@ -309,16 +316,14 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[inline] #[inline]
fn poll_stream(&mut self) -> Poll<bool, PayloadError> { fn poll_stream(&mut self) -> Poll<bool, PayloadError> {
self.stream.poll().map(|res| { self.stream.poll().map(|res| match res {
match res { Async::Ready(Some(data)) => {
Async::Ready(Some(data)) => { self.len += data.len();
self.len += data.len(); self.items.push_back(data);
self.items.push_back(data); Async::Ready(true)
Async::Ready(true)
},
Async::Ready(None) => Async::Ready(false),
Async::NotReady => Async::NotReady,
} }
Async::Ready(None) => Async::Ready(false),
Async::NotReady => Async::NotReady,
}) })
} }
@ -373,11 +378,9 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
let buf = chunk.split_to(size); let buf = chunk.split_to(size);
self.items.push_front(chunk); self.items.push_front(chunk);
Ok(Async::Ready(Some(buf))) Ok(Async::Ready(Some(buf)))
} } else if size == chunk.len() {
else if size == chunk.len() {
Ok(Async::Ready(Some(chunk))) Ok(Async::Ready(Some(chunk)))
} } else {
else {
let mut buf = BytesMut::with_capacity(size); let mut buf = BytesMut::with_capacity(size);
buf.extend_from_slice(&chunk); buf.extend_from_slice(&chunk);
@ -408,7 +411,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
let mut len = 0; let mut len = 0;
while len < size { while len < size {
let mut chunk = self.items.pop_front().unwrap(); let mut chunk = self.items.pop_front().unwrap();
let rem = cmp::min(size-len, chunk.len()); let rem = cmp::min(size - len, chunk.len());
len += rem; len += rem;
if rem < chunk.len() { if rem < chunk.len() {
chunk.split_to(rem); chunk.split_to(rem);
@ -427,7 +430,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
buf.extend_from_slice(&chunk[..rem]); buf.extend_from_slice(&chunk[..rem]);
} }
if buf.len() == size { if buf.len() == size {
return Ok(Async::Ready(Some(buf))) return Ok(Async::Ready(Some(buf)));
} }
} }
} }
@ -454,8 +457,8 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
idx += 1; idx += 1;
if idx == line.len() { if idx == line.len() {
num = no; num = no;
offset = pos+1; offset = pos + 1;
length += pos+1; length += pos + 1;
found = true; found = true;
break; break;
} }
@ -483,7 +486,7 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
} }
self.len -= length; self.len -= length;
return Ok(Async::Ready(Some(buf.freeze()))) return Ok(Async::Ready(Some(buf.freeze())));
} }
} }
@ -505,171 +508,217 @@ impl<S> PayloadHelper<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[allow(dead_code)] #[allow(dead_code)]
pub fn remaining(&mut self) -> Bytes { pub fn remaining(&mut self) -> Bytes {
self.items.iter_mut() self.items
.iter_mut()
.fold(BytesMut::new(), |mut b, c| { .fold(BytesMut::new(), |mut b, c| {
b.extend_from_slice(c); b.extend_from_slice(c);
b b
}).freeze() })
.freeze()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::io;
use failure::Fail; use failure::Fail;
use futures::future::{lazy, result}; use futures::future::{lazy, result};
use std::io;
use tokio_core::reactor::Core; use tokio_core::reactor::Core;
#[test] #[test]
fn test_error() { fn test_error() {
let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); let err: PayloadError =
io::Error::new(io::ErrorKind::Other, "ParseError").into();
assert_eq!(format!("{}", err), "ParseError"); assert_eq!(format!("{}", err), "ParseError");
assert_eq!(format!("{}", err.cause().unwrap()), "ParseError"); assert_eq!(format!("{}", err.cause().unwrap()), "ParseError");
let err = PayloadError::Incomplete; let err = PayloadError::Incomplete;
assert_eq!(format!("{}", err), "A payload reached EOF, but is not complete."); assert_eq!(
format!("{}", err),
"A payload reached EOF, but is not complete."
);
} }
#[test] #[test]
fn test_basic() { fn test_basic() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (_, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (_, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(payload.len, 0); assert_eq!(payload.len, 0);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_eof() { fn test_eof() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.feed_data(Bytes::from("data")); sender.feed_data(Bytes::from("data"));
sender.feed_eof(); sender.feed_eof();
assert_eq!(Async::Ready(Some(Bytes::from("data"))), assert_eq!(
payload.readany().ok().unwrap()); Async::Ready(Some(Bytes::from("data"))),
assert_eq!(payload.len, 0); payload.readany().ok().unwrap()
assert_eq!(Async::Ready(None), payload.readany().ok().unwrap()); );
assert_eq!(payload.len, 0);
assert_eq!(Async::Ready(None), payload.readany().ok().unwrap());
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_err() { fn test_err() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete);
payload.readany().err().unwrap(); payload.readany().err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_readany() { fn test_readany() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from("line1"))), assert_eq!(
payload.readany().ok().unwrap()); Async::Ready(Some(Bytes::from("line1"))),
assert_eq!(payload.len, 0); payload.readany().ok().unwrap()
);
assert_eq!(payload.len, 0);
assert_eq!(Async::Ready(Some(Bytes::from("line2"))), assert_eq!(
payload.readany().ok().unwrap()); Async::Ready(Some(Bytes::from("line2"))),
assert_eq!(payload.len, 0); payload.readany().ok().unwrap()
);
assert_eq!(payload.len, 0);
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_readexactly() { fn test_readexactly() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap()); assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap());
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from_static(b"li"))), assert_eq!(
payload.read_exact(2).ok().unwrap()); Async::Ready(Some(Bytes::from_static(b"li"))),
assert_eq!(payload.len, 3); payload.read_exact(2).ok().unwrap()
);
assert_eq!(payload.len, 3);
assert_eq!(Async::Ready(Some(Bytes::from_static(b"ne1l"))), assert_eq!(
payload.read_exact(4).ok().unwrap()); Async::Ready(Some(Bytes::from_static(b"ne1l"))),
assert_eq!(payload.len, 4); payload.read_exact(4).ok().unwrap()
);
assert_eq!(payload.len, 4);
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete);
payload.read_exact(10).err().unwrap(); payload.read_exact(10).err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_readuntil() { fn test_readuntil() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (mut sender, payload) = Payload::new(false); .unwrap()
let mut payload = PayloadHelper::new(payload); .run(lazy(|| {
let (mut sender, payload) = Payload::new(false);
let mut payload = PayloadHelper::new(payload);
assert_eq!(Async::NotReady, payload.read_until(b"ne").ok().unwrap()); assert_eq!(
Async::NotReady,
payload.read_until(b"ne").ok().unwrap()
);
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
assert_eq!(Async::Ready(Some(Bytes::from("line"))), assert_eq!(
payload.read_until(b"ne").ok().unwrap()); Async::Ready(Some(Bytes::from("line"))),
assert_eq!(payload.len, 1); payload.read_until(b"ne").ok().unwrap()
);
assert_eq!(payload.len, 1);
assert_eq!(Async::Ready(Some(Bytes::from("1line2"))), assert_eq!(
payload.read_until(b"2").ok().unwrap()); Async::Ready(Some(Bytes::from("1line2"))),
assert_eq!(payload.len, 0); payload.read_until(b"2").ok().unwrap()
);
assert_eq!(payload.len, 0);
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete);
payload.read_until(b"b").err().unwrap(); payload.read_until(b"b").err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
#[test] #[test]
fn test_unread_data() { fn test_unread_data() {
Core::new().unwrap().run(lazy(|| { Core::new()
let (_, mut payload) = Payload::new(false); .unwrap()
.run(lazy(|| {
let (_, mut payload) = Payload::new(false);
payload.unread_data(Bytes::from("data")); payload.unread_data(Bytes::from("data"));
assert!(!payload.is_empty()); assert!(!payload.is_empty());
assert_eq!(payload.len(), 4); assert_eq!(payload.len(), 4);
assert_eq!(Async::Ready(Some(Bytes::from("data"))), assert_eq!(
payload.poll().ok().unwrap()); Async::Ready(Some(Bytes::from("data"))),
payload.poll().ok().unwrap()
);
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
})).unwrap(); }))
.unwrap();
} }
} }

View File

@ -1,22 +1,22 @@
use std::{io, mem};
use std::rc::Rc;
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc;
use std::{io, mem};
use log::Level::Debug;
use futures::{Async, Poll, Future, Stream};
use futures::unsync::oneshot; use futures::unsync::oneshot;
use futures::{Async, Future, Poll, Stream};
use log::Level::Debug;
use application::Inner;
use body::{Body, BodyStream}; use body::{Body, BodyStream};
use context::{Frame, ActorHttpContext}; use context::{ActorHttpContext, Frame};
use error::Error; use error::Error;
use header::ContentEncoding;
use handler::{Reply, ReplyItem}; use handler::{Reply, ReplyItem};
use header::ContentEncoding;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::{Middleware, Finished, Started, Response}; use middleware::{Finished, Middleware, Response, Started};
use application::Inner; use server::{HttpHandlerTask, Writer, WriterState};
use server::{Writer, WriterState, HttpHandlerTask};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub(crate) enum HandlerType { pub(crate) enum HandlerType {
@ -26,7 +26,6 @@ pub(crate) enum HandlerType {
} }
pub(crate) trait PipelineHandler<S> { pub(crate) trait PipelineHandler<S> {
fn encoding(&self) -> ContentEncoding; fn encoding(&self) -> ContentEncoding;
fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply; fn handle(&mut self, req: HttpRequest<S>, htype: HandlerType) -> Reply;
@ -46,7 +45,6 @@ enum PipelineState<S, H> {
} }
impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> { impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> {
fn is_response(&self) -> bool { fn is_response(&self) -> bool {
match *self { match *self {
PipelineState::Response(_) => true, PipelineState::Response(_) => true,
@ -61,10 +59,12 @@ impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> {
PipelineState::RunMiddlewares(ref mut state) => state.poll(info), PipelineState::RunMiddlewares(ref mut state) => state.poll(info),
PipelineState::Finishing(ref mut state) => state.poll(info), PipelineState::Finishing(ref mut state) => state.poll(info),
PipelineState::Completed(ref mut state) => state.poll(info), PipelineState::Completed(ref mut state) => state.poll(info),
PipelineState::Response(_) | PipelineState::None | PipelineState::Error => None, PipelineState::Response(_) | PipelineState::None | PipelineState::Error => {
None
}
} }
} }
} }
struct PipelineInfo<S> { struct PipelineInfo<S> {
req: HttpRequest<S>, req: HttpRequest<S>,
@ -92,7 +92,9 @@ impl<S> PipelineInfo<S> {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))]
fn req_mut(&self) -> &mut HttpRequest<S> { fn req_mut(&self) -> &mut HttpRequest<S> {
#[allow(mutable_transmutes)] #[allow(mutable_transmutes)]
unsafe{mem::transmute(&self.req)} unsafe {
mem::transmute(&self.req)
}
} }
fn poll_context(&mut self) -> Poll<(), Error> { fn poll_context(&mut self) -> Poll<(), Error> {
@ -109,18 +111,18 @@ impl<S> PipelineInfo<S> {
} }
impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> { impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> {
pub fn new(
pub fn new(req: HttpRequest<S>, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>,
mws: Rc<Vec<Box<Middleware<S>>>>, handler: Rc<UnsafeCell<H>>, htype: HandlerType,
handler: Rc<UnsafeCell<H>>, htype: HandlerType) -> Pipeline<S, H> ) -> Pipeline<S, H> {
{
let mut info = PipelineInfo { let mut info = PipelineInfo {
req, mws, req,
mws,
count: 0, count: 0,
error: None, error: None,
context: None, context: None,
disconnected: None, disconnected: None,
encoding: unsafe{&*handler.get()}.encoding(), encoding: unsafe { &*handler.get() }.encoding(),
}; };
let state = StartMiddlewares::init(&mut info, handler, htype); let state = StartMiddlewares::init(&mut info, handler, htype);
@ -131,30 +133,33 @@ impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> {
impl Pipeline<(), Inner<()>> { impl Pipeline<(), Inner<()>> {
pub fn error<R: Into<HttpResponse>>(err: R) -> Box<HttpHandlerTask> { pub fn error<R: Into<HttpResponse>>(err: R) -> Box<HttpHandlerTask> {
Box::new(Pipeline::<(), Inner<()>>( Box::new(Pipeline::<(), Inner<()>>(
PipelineInfo::new(HttpRequest::default()), ProcessResponse::init(err.into()))) PipelineInfo::new(HttpRequest::default()),
ProcessResponse::init(err.into()),
))
} }
} }
impl<S: 'static, H> Pipeline<S, H> { impl<S: 'static, H> Pipeline<S, H> {
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
match self.1 { match self.1 {
PipelineState::None | PipelineState::Error PipelineState::None
| PipelineState::Starting(_) | PipelineState::Handler(_) | PipelineState::Error
| PipelineState::RunMiddlewares(_) | PipelineState::Response(_) => true, | PipelineState::Starting(_)
| PipelineState::Handler(_)
| PipelineState::RunMiddlewares(_)
| PipelineState::Response(_) => true,
PipelineState::Finishing(_) | PipelineState::Completed(_) => false, PipelineState::Finishing(_) | PipelineState::Completed(_) => false,
} }
} }
} }
impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> { impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
fn disconnected(&mut self) { fn disconnected(&mut self) {
self.0.disconnected = Some(true); self.0.disconnected = Some(true);
} }
fn poll_io(&mut self, io: &mut Writer) -> Poll<bool, Error> { fn poll_io(&mut self, io: &mut Writer) -> Poll<bool, Error> {
let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) }; let info: &mut PipelineInfo<_> = unsafe { mem::transmute(&mut self.0) };
loop { loop {
if self.1.is_response() { if self.1.is_response() {
@ -164,9 +169,9 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
Ok(state) => { Ok(state) => {
self.1 = state; self.1 = state;
if let Some(error) = self.0.error.take() { if let Some(error) = self.0.error.take() {
return Err(error) return Err(error);
} else { } else {
return Ok(Async::Ready(self.is_done())) return Ok(Async::Ready(self.is_done()));
} }
} }
Err(state) => { Err(state) => {
@ -177,11 +182,10 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
} }
} }
match self.1 { match self.1 {
PipelineState::None => PipelineState::None => return Ok(Async::Ready(true)),
return Ok(Async::Ready(true)), PipelineState::Error => {
PipelineState::Error => return Err(io::Error::new(io::ErrorKind::Other, "Internal error").into())
return Err(io::Error::new( }
io::ErrorKind::Other, "Internal error").into()),
_ => (), _ => (),
} }
@ -193,7 +197,7 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
} }
fn poll(&mut self) -> Poll<(), Error> { fn poll(&mut self) -> Poll<(), Error> {
let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) }; let info: &mut PipelineInfo<_> = unsafe { mem::transmute(&mut self.0) };
loop { loop {
match self.1 { match self.1 {
@ -212,7 +216,7 @@ impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> {
} }
} }
type Fut = Box<Future<Item=Option<HttpResponse>, Error=Error>>; type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
/// Middlewares start executor /// Middlewares start executor
struct StartMiddlewares<S, H> { struct StartMiddlewares<S, H> {
@ -223,41 +227,40 @@ struct StartMiddlewares<S, H> {
} }
impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> { impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
fn init(
fn init(info: &mut PipelineInfo<S>, hnd: Rc<UnsafeCell<H>>, htype: HandlerType) info: &mut PipelineInfo<S>, hnd: Rc<UnsafeCell<H>>, htype: HandlerType
-> PipelineState<S, H> ) -> PipelineState<S, H> {
{ // execute middlewares, we need this stage because middlewares could be
// execute middlewares, we need this stage because middlewares could be non-async // non-async and we can move to next state immediately
// and we can move to next state immediately
let len = info.mws.len() as u16; let len = info.mws.len() as u16;
loop { loop {
if info.count == len { if info.count == len {
let reply = unsafe{&mut *hnd.get()}.handle(info.req.clone(), htype); let reply = unsafe { &mut *hnd.get() }.handle(info.req.clone(), htype);
return WaitingResponse::init(info, reply) return WaitingResponse::init(info, reply);
} else { } else {
match info.mws[info.count as usize].start(&mut info.req) { match info.mws[info.count as usize].start(&mut info.req) {
Ok(Started::Done) => Ok(Started::Done) => info.count += 1,
info.count += 1, Ok(Started::Response(resp)) => {
Ok(Started::Response(resp)) => return RunMiddlewares::init(info, resp)
return RunMiddlewares::init(info, resp), }
Ok(Started::Future(mut fut)) => Ok(Started::Future(mut fut)) => match fut.poll() {
match fut.poll() { Ok(Async::NotReady) => {
Ok(Async::NotReady) => return PipelineState::Starting(StartMiddlewares {
return PipelineState::Starting(StartMiddlewares { hnd,
hnd, htype, htype,
fut: Some(fut), fut: Some(fut),
_s: PhantomData}), _s: PhantomData,
Ok(Async::Ready(resp)) => { })
if let Some(resp) = resp { }
return RunMiddlewares::init(info, resp); Ok(Async::Ready(resp)) => {
} if let Some(resp) = resp {
info.count += 1; return RunMiddlewares::init(info, resp);
} }
Err(err) => info.count += 1;
return ProcessResponse::init(err.into()), }
}, Err(err) => return ProcessResponse::init(err.into()),
Err(err) => },
return ProcessResponse::init(err.into()), Err(err) => return ProcessResponse::init(err.into()),
} }
} }
} }
@ -274,29 +277,28 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
} }
if info.count == len { if info.count == len {
let reply = unsafe{ let reply = unsafe { &mut *self.hnd.get() }
&mut *self.hnd.get()}.handle(info.req.clone(), self.htype); .handle(info.req.clone(), self.htype);
return Some(WaitingResponse::init(info, reply)); return Some(WaitingResponse::init(info, reply));
} else { } else {
loop { loop {
match info.mws[info.count as usize].start(info.req_mut()) { match info.mws[info.count as usize].start(info.req_mut()) {
Ok(Started::Done) => Ok(Started::Done) => info.count += 1,
info.count += 1,
Ok(Started::Response(resp)) => { Ok(Started::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
}, }
Ok(Started::Future(fut)) => { Ok(Started::Future(fut)) => {
self.fut = Some(fut); self.fut = Some(fut);
continue 'outer continue 'outer;
}, }
Err(err) => Err(err) => {
return Some(ProcessResponse::init(err.into())) return Some(ProcessResponse::init(err.into()))
}
} }
} }
} }
} }
Err(err) => Err(err) => return Some(ProcessResponse::init(err.into())),
return Some(ProcessResponse::init(err.into()))
} }
} }
} }
@ -304,31 +306,29 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
// waiting for response // waiting for response
struct WaitingResponse<S, H> { struct WaitingResponse<S, H> {
fut: Box<Future<Item=HttpResponse, Error=Error>>, fut: Box<Future<Item = HttpResponse, Error = Error>>,
_s: PhantomData<S>, _s: PhantomData<S>,
_h: PhantomData<H>, _h: PhantomData<H>,
} }
impl<S: 'static, H> WaitingResponse<S, H> { impl<S: 'static, H> WaitingResponse<S, H> {
#[inline] #[inline]
fn init(info: &mut PipelineInfo<S>, reply: Reply) -> PipelineState<S, H> { fn init(info: &mut PipelineInfo<S>, reply: Reply) -> PipelineState<S, H> {
match reply.into() { match reply.into() {
ReplyItem::Message(resp) => ReplyItem::Message(resp) => RunMiddlewares::init(info, resp),
RunMiddlewares::init(info, resp), ReplyItem::Future(fut) => PipelineState::Handler(WaitingResponse {
ReplyItem::Future(fut) => fut,
PipelineState::Handler( _s: PhantomData,
WaitingResponse { fut, _s: PhantomData, _h: PhantomData }), _h: PhantomData,
}),
} }
} }
fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> { fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> {
match self.fut.poll() { match self.fut.poll() {
Ok(Async::NotReady) => None, Ok(Async::NotReady) => None,
Ok(Async::Ready(response)) => Ok(Async::Ready(response)) => Some(RunMiddlewares::init(info, response)),
Some(RunMiddlewares::init(info, response)), Err(err) => Some(ProcessResponse::init(err.into())),
Err(err) =>
Some(ProcessResponse::init(err.into())),
} }
} }
} }
@ -336,13 +336,12 @@ impl<S: 'static, H> WaitingResponse<S, H> {
/// Middlewares response executor /// Middlewares response executor
struct RunMiddlewares<S, H> { struct RunMiddlewares<S, H> {
curr: usize, curr: usize,
fut: Option<Box<Future<Item=HttpResponse, Error=Error>>>, fut: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
_s: PhantomData<S>, _s: PhantomData<S>,
_h: PhantomData<H>, _h: PhantomData<H>,
} }
impl<S: 'static, H> RunMiddlewares<S, H> { impl<S: 'static, H> RunMiddlewares<S, H> {
fn init(info: &mut PipelineInfo<S>, mut resp: HttpResponse) -> PipelineState<S, H> { fn init(info: &mut PipelineInfo<S>, mut resp: HttpResponse) -> PipelineState<S, H> {
if info.count == 0 { if info.count == 0 {
return ProcessResponse::init(resp); return ProcessResponse::init(resp);
@ -354,21 +353,24 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
resp = match info.mws[curr].response(info.req_mut(), resp) { resp = match info.mws[curr].response(info.req_mut(), resp) {
Err(err) => { Err(err) => {
info.count = (curr + 1) as u16; info.count = (curr + 1) as u16;
return ProcessResponse::init(err.into()) return ProcessResponse::init(err.into());
} }
Ok(Response::Done(r)) => { Ok(Response::Done(r)) => {
curr += 1; curr += 1;
if curr == len { if curr == len {
return ProcessResponse::init(r) return ProcessResponse::init(r);
} else { } else {
r r
} }
}, }
Ok(Response::Future(fut)) => { Ok(Response::Future(fut)) => {
return PipelineState::RunMiddlewares( return PipelineState::RunMiddlewares(RunMiddlewares {
RunMiddlewares { curr, fut: Some(fut), curr,
_s: PhantomData, _h: PhantomData }) fut: Some(fut),
}, _s: PhantomData,
_h: PhantomData,
})
}
}; };
} }
} }
@ -379,15 +381,12 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
loop { loop {
// poll latest fut // poll latest fut
let mut resp = match self.fut.as_mut().unwrap().poll() { let mut resp = match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => { Ok(Async::NotReady) => return None,
return None
}
Ok(Async::Ready(resp)) => { Ok(Async::Ready(resp)) => {
self.curr += 1; self.curr += 1;
resp resp
} }
Err(err) => Err(err) => return Some(ProcessResponse::init(err.into())),
return Some(ProcessResponse::init(err.into())),
}; };
loop { loop {
@ -395,16 +394,15 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
return Some(ProcessResponse::init(resp)); return Some(ProcessResponse::init(resp));
} else { } else {
match info.mws[self.curr].response(info.req_mut(), resp) { match info.mws[self.curr].response(info.req_mut(), resp) {
Err(err) => Err(err) => return Some(ProcessResponse::init(err.into())),
return Some(ProcessResponse::init(err.into())),
Ok(Response::Done(r)) => { Ok(Response::Done(r)) => {
self.curr += 1; self.curr += 1;
resp = r resp = r
}, }
Ok(Response::Future(fut)) => { Ok(Response::Future(fut)) => {
self.fut = Some(fut); self.fut = Some(fut);
break break;
}, }
} }
} }
} }
@ -451,42 +449,56 @@ enum IOState {
} }
impl<S: 'static, H> ProcessResponse<S, H> { impl<S: 'static, H> ProcessResponse<S, H> {
#[inline] #[inline]
fn init(resp: HttpResponse) -> PipelineState<S, H> { fn init(resp: HttpResponse) -> PipelineState<S, H> {
PipelineState::Response( PipelineState::Response(ProcessResponse {
ProcessResponse{ resp, resp,
iostate: IOState::Response, iostate: IOState::Response,
running: RunningState::Running, running: RunningState::Running,
drain: None, _s: PhantomData, _h: PhantomData}) drain: None,
_s: PhantomData,
_h: PhantomData,
})
} }
fn poll_io(mut self, io: &mut Writer, info: &mut PipelineInfo<S>) fn poll_io(
-> Result<PipelineState<S, H>, PipelineState<S, H>> mut self, io: &mut Writer, info: &mut PipelineInfo<S>
{ ) -> Result<PipelineState<S, H>, PipelineState<S, H>> {
loop { loop {
if self.drain.is_none() && self.running != RunningState::Paused { if self.drain.is_none() && self.running != RunningState::Paused {
// if task is paused, write buffer is probably full // if task is paused, write buffer is probably full
'inner: loop { 'inner: loop {
let result = match mem::replace(&mut self.iostate, IOState::Done) { let result = match mem::replace(&mut self.iostate, IOState::Done) {
IOState::Response => { IOState::Response => {
let encoding = self.resp.content_encoding().unwrap_or(info.encoding); let encoding =
self.resp.content_encoding().unwrap_or(info.encoding);
let result = match io.start(info.req_mut().get_inner(), let result = match io.start(
&mut self.resp, encoding) info.req_mut().get_inner(),
{ &mut self.resp,
encoding,
) {
Ok(res) => res, Ok(res) => res,
Err(err) => { Err(err) => {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp)) return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
} }
}; };
if let Some(err) = self.resp.error() { if let Some(err) = self.resp.error() {
if self.resp.status().is_server_error() { if self.resp.status().is_server_error() {
error!("Error occured during request handling: {}", err); error!(
"Error occured during request handling: {}",
err
);
} else { } else {
warn!("Error occured during request handling: {}", err); warn!(
"Error occured during request handling: {}",
err
);
} }
if log_enabled!(Debug) { if log_enabled!(Debug) {
debug!("{:?}", err); debug!("{:?}", err);
@ -497,44 +509,48 @@ impl<S: 'static, H> ProcessResponse<S, H> {
match self.resp.replace_body(Body::Empty) { match self.resp.replace_body(Body::Empty) {
Body::Streaming(stream) => { Body::Streaming(stream) => {
self.iostate = IOState::Payload(stream); self.iostate = IOState::Payload(stream);
continue 'inner continue 'inner;
}, }
Body::Actor(ctx) => { Body::Actor(ctx) => {
self.iostate = IOState::Actor(ctx); self.iostate = IOState::Actor(ctx);
continue 'inner continue 'inner;
}, }
_ => (), _ => (),
} }
result result
}, }
IOState::Payload(mut body) => { IOState::Payload(mut body) => match body.poll() {
match body.poll() { Ok(Async::Ready(None)) => {
Ok(Async::Ready(None)) => { if let Err(err) = io.write_eof() {
if let Err(err) = io.write_eof() { info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
}
break;
}
Ok(Async::Ready(Some(chunk))) => {
self.iostate = IOState::Payload(body);
match io.write(chunk.into()) {
Err(err) => {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp)) return Ok(FinishingMiddlewares::init(
} info,
break self.resp,
}, ));
Ok(Async::Ready(Some(chunk))) => {
self.iostate = IOState::Payload(body);
match io.write(chunk.into()) {
Err(err) => {
info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp))
},
Ok(result) => result
} }
Ok(result) => result,
} }
Ok(Async::NotReady) => { }
self.iostate = IOState::Payload(body); Ok(Async::NotReady) => {
break self.iostate = IOState::Payload(body);
}, break;
Err(err) => { }
info.error = Some(err); Err(err) => {
return Ok(FinishingMiddlewares::init(info, self.resp)) info.error = Some(err);
} return Ok(FinishingMiddlewares::init(info, self.resp));
} }
}, },
IOState::Actor(mut ctx) => { IOState::Actor(mut ctx) => {
@ -545,7 +561,7 @@ impl<S: 'static, H> ProcessResponse<S, H> {
Ok(Async::Ready(Some(vec))) => { Ok(Async::Ready(Some(vec))) => {
if vec.is_empty() { if vec.is_empty() {
self.iostate = IOState::Actor(ctx); self.iostate = IOState::Actor(ctx);
break break;
} }
let mut res = None; let mut res = None;
for frame in vec { for frame in vec {
@ -555,40 +571,49 @@ impl<S: 'static, H> ProcessResponse<S, H> {
if let Err(err) = io.write_eof() { if let Err(err) = io.write_eof() {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok( return Ok(
FinishingMiddlewares::init(info, self.resp)) FinishingMiddlewares::init(
info,
self.resp,
),
);
} }
break 'inner break 'inner;
}, }
Frame::Chunk(Some(chunk)) => { Frame::Chunk(Some(chunk)) => {
match io.write(chunk) { match io.write(chunk) {
Err(err) => { Err(err) => {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok( return Ok(
FinishingMiddlewares::init(info, self.resp)) FinishingMiddlewares::init(
}, info,
self.resp,
),
);
}
Ok(result) => res = Some(result), Ok(result) => res = Some(result),
} }
}, }
Frame::Drain(fut) => self.drain = Some(fut), Frame::Drain(fut) => self.drain = Some(fut),
} }
} }
self.iostate = IOState::Actor(ctx); self.iostate = IOState::Actor(ctx);
if self.drain.is_some() { if self.drain.is_some() {
self.running.resume(); self.running.resume();
break 'inner break 'inner;
} }
res.unwrap() res.unwrap()
},
Ok(Async::Ready(None)) => {
break
} }
Ok(Async::Ready(None)) => break,
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.iostate = IOState::Actor(ctx); self.iostate = IOState::Actor(ctx);
break break;
} }
Err(err) => { Err(err) => {
info.error = Some(err); info.error = Some(err);
return Ok(FinishingMiddlewares::init(info, self.resp)) return Ok(FinishingMiddlewares::init(
info,
self.resp,
));
} }
} }
} }
@ -598,11 +623,9 @@ impl<S: 'static, H> ProcessResponse<S, H> {
match result { match result {
WriterState::Pause => { WriterState::Pause => {
self.running.pause(); self.running.pause();
break break;
} }
WriterState::Done => { WriterState::Done => self.running.resume(),
self.running.resume()
},
} }
} }
} }
@ -618,17 +641,16 @@ impl<S: 'static, H> ProcessResponse<S, H> {
let _ = tx.send(()); let _ = tx.send(());
} }
// restart io processing // restart io processing
continue continue;
}, }
Ok(Async::NotReady) => Ok(Async::NotReady) => return Err(PipelineState::Response(self)),
return Err(PipelineState::Response(self)),
Err(err) => { Err(err) => {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp)) return Ok(FinishingMiddlewares::init(info, self.resp));
} }
} }
} }
break break;
} }
// response is completed // response is completed
@ -638,7 +660,7 @@ impl<S: 'static, H> ProcessResponse<S, H> {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
info.error = Some(err.into()); info.error = Some(err.into());
return Ok(FinishingMiddlewares::init(info, self.resp)) return Ok(FinishingMiddlewares::init(info, self.resp));
} }
} }
self.resp.set_response_size(io.written()); self.resp.set_response_size(io.written());
@ -652,19 +674,22 @@ impl<S: 'static, H> ProcessResponse<S, H> {
/// Middlewares start executor /// Middlewares start executor
struct FinishingMiddlewares<S, H> { struct FinishingMiddlewares<S, H> {
resp: HttpResponse, resp: HttpResponse,
fut: Option<Box<Future<Item=(), Error=Error>>>, fut: Option<Box<Future<Item = (), Error = Error>>>,
_s: PhantomData<S>, _s: PhantomData<S>,
_h: PhantomData<H>, _h: PhantomData<H>,
} }
impl<S: 'static, H> FinishingMiddlewares<S, H> { impl<S: 'static, H> FinishingMiddlewares<S, H> {
fn init(info: &mut PipelineInfo<S>, resp: HttpResponse) -> PipelineState<S, H> { fn init(info: &mut PipelineInfo<S>, resp: HttpResponse) -> PipelineState<S, H> {
if info.count == 0 { if info.count == 0 {
Completed::init(info) Completed::init(info)
} else { } else {
let mut state = FinishingMiddlewares{resp, fut: None, let mut state = FinishingMiddlewares {
_s: PhantomData, _h: PhantomData}; resp,
fut: None,
_s: PhantomData,
_h: PhantomData,
};
if let Some(st) = state.poll(info) { if let Some(st) = state.poll(info) {
st st
} else { } else {
@ -678,12 +703,8 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
// poll latest fut // poll latest fut
let not_ready = if let Some(ref mut fut) = self.fut { let not_ready = if let Some(ref mut fut) = self.fut {
match fut.poll() { match fut.poll() {
Ok(Async::NotReady) => { Ok(Async::NotReady) => true,
true Ok(Async::Ready(())) => false,
},
Ok(Async::Ready(())) => {
false
},
Err(err) => { Err(err) => {
error!("Middleware finish error: {}", err); error!("Middleware finish error: {}", err);
false false
@ -701,12 +722,12 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) { match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) {
Finished::Done => { Finished::Done => {
if info.count == 0 { if info.count == 0 {
return Some(Completed::init(info)) return Some(Completed::init(info));
} }
} }
Finished::Future(fut) => { Finished::Future(fut) => {
self.fut = Some(fut); self.fut = Some(fut);
}, }
} }
} }
} }
@ -716,7 +737,6 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
struct Completed<S, H>(PhantomData<S>, PhantomData<H>); struct Completed<S, H>(PhantomData<S>, PhantomData<H>);
impl<S, H> Completed<S, H> { impl<S, H> Completed<S, H> {
#[inline] #[inline]
fn init(info: &mut PipelineInfo<S>) -> PipelineState<S, H> { fn init(info: &mut PipelineInfo<S>) -> PipelineState<S, H> {
if let Some(ref err) = info.error { if let Some(ref err) = info.error {
@ -745,15 +765,23 @@ mod tests {
use super::*; use super::*;
use actix::*; use actix::*;
use context::HttpContext; use context::HttpContext;
use tokio_core::reactor::Core;
use futures::future::{lazy, result}; use futures::future::{lazy, result};
use tokio_core::reactor::Core;
impl<S, H> PipelineState<S, H> { impl<S, H> PipelineState<S, H> {
fn is_none(&self) -> Option<bool> { fn is_none(&self) -> Option<bool> {
if let PipelineState::None = *self { Some(true) } else { None } if let PipelineState::None = *self {
Some(true)
} else {
None
}
} }
fn completed(self) -> Option<Completed<S, H>> { fn completed(self) -> Option<Completed<S, H>> {
if let PipelineState::Completed(c) = self { Some(c) } else { None } if let PipelineState::Completed(c) = self {
Some(c)
} else {
None
}
} }
} }
@ -764,28 +792,35 @@ mod tests {
#[test] #[test]
fn test_completed() { fn test_completed() {
Core::new().unwrap().run(lazy(|| { Core::new()
let mut info = PipelineInfo::new(HttpRequest::default()); .unwrap()
Completed::<(), Inner<()>>::init(&mut info).is_none().unwrap(); .run(lazy(|| {
let mut info = PipelineInfo::new(HttpRequest::default());
Completed::<(), Inner<()>>::init(&mut info)
.is_none()
.unwrap();
let req = HttpRequest::default(); let req = HttpRequest::default();
let mut ctx = HttpContext::new(req.clone(), MyActor); let mut ctx = HttpContext::new(req.clone(), MyActor);
let addr: Addr<Unsync, _> = ctx.address(); let addr: Addr<Unsync, _> = ctx.address();
let mut info = PipelineInfo::new(req); let mut info = PipelineInfo::new(req);
info.context = Some(Box::new(ctx)); info.context = Some(Box::new(ctx));
let mut state = Completed::<(), Inner<()>>::init(&mut info).completed().unwrap(); let mut state = Completed::<(), Inner<()>>::init(&mut info)
.completed()
.unwrap();
assert!(state.poll(&mut info).is_none()); assert!(state.poll(&mut info).is_none());
let pp = Pipeline(info, PipelineState::Completed(state)); let pp = Pipeline(info, PipelineState::Completed(state));
assert!(!pp.is_done()); assert!(!pp.is_done());
let Pipeline(mut info, st) = pp; let Pipeline(mut info, st) = pp;
let mut st = st.completed().unwrap(); let mut st = st.completed().unwrap();
drop(addr); drop(addr);
assert!(st.poll(&mut info).unwrap().is_none().unwrap()); assert!(st.poll(&mut info).unwrap().is_none().unwrap());
result(Ok::<_, ()>(())) result(Ok::<_, ()>(()))
})).unwrap(); }))
.unwrap();
} }
} }

View File

@ -1,20 +1,18 @@
//! Route match predicates //! Route match predicates
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::marker::PhantomData;
use http; use http;
use http::{header, HttpTryFrom}; use http::{header, HttpTryFrom};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use std::marker::PhantomData;
/// Trait defines resource route predicate. /// Trait defines resource route predicate.
/// Predicate can modify request object. It is also possible to /// Predicate can modify request object. It is also possible to
/// to store extra attributes on request by using `Extensions` container, /// to store extra attributes on request by using `Extensions` container,
/// Extensions container available via `HttpRequest::extensions()` method. /// Extensions container available via `HttpRequest::extensions()` method.
pub trait Predicate<S> { pub trait Predicate<S> {
/// Check if request matches predicate /// Check if request matches predicate
fn check(&self, &mut HttpRequest<S>) -> bool; fn check(&self, &mut HttpRequest<S>) -> bool;
} }
/// Return predicate that matches if any of supplied predicate matches. /// Return predicate that matches if any of supplied predicate matches.
@ -30,8 +28,7 @@ pub trait Predicate<S> {
/// .f(|r| HttpResponse::MethodNotAllowed())); /// .f(|r| HttpResponse::MethodNotAllowed()));
/// } /// }
/// ``` /// ```
pub fn Any<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AnyPredicate<S> pub fn Any<S: 'static, P: Predicate<S> + 'static>(pred: P) -> AnyPredicate<S> {
{
AnyPredicate(vec![Box::new(pred)]) AnyPredicate(vec![Box::new(pred)])
} }
@ -50,7 +47,7 @@ impl<S: 'static> Predicate<S> for AnyPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool { fn check(&self, req: &mut HttpRequest<S>) -> bool {
for p in &self.0 { for p in &self.0 {
if p.check(req) { if p.check(req) {
return true return true;
} }
} }
false false
@ -90,7 +87,7 @@ impl<S: 'static> Predicate<S> for AllPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool { fn check(&self, req: &mut HttpRequest<S>) -> bool {
for p in &self.0 { for p in &self.0 {
if !p.check(req) { if !p.check(req) {
return false return false;
} }
} }
true true
@ -98,8 +95,7 @@ impl<S: 'static> Predicate<S> for AllPredicate<S> {
} }
/// Return predicate that matches if supplied predicate does not match. /// Return predicate that matches if supplied predicate does not match.
pub fn Not<S: 'static, P: Predicate<S> + 'static>(pred: P) -> NotPredicate<S> pub fn Not<S: 'static, P: Predicate<S> + 'static>(pred: P) -> NotPredicate<S> {
{
NotPredicate(Box::new(pred)) NotPredicate(Box::new(pred))
} }
@ -172,21 +168,29 @@ pub fn Method<S: 'static>(method: http::Method) -> MethodPredicate<S> {
MethodPredicate(method, PhantomData) MethodPredicate(method, PhantomData)
} }
/// Return predicate that matches if request contains specified header and value. /// Return predicate that matches if request contains specified header and
pub fn Header<S: 'static>(name: &'static str, value: &'static str) -> HeaderPredicate<S> /// value.
{ pub fn Header<S: 'static>(
HeaderPredicate(header::HeaderName::try_from(name).unwrap(), name: &'static str, value: &'static str
header::HeaderValue::from_static(value), ) -> HeaderPredicate<S> {
PhantomData) HeaderPredicate(
header::HeaderName::try_from(name).unwrap(),
header::HeaderValue::from_static(value),
PhantomData,
)
} }
#[doc(hidden)] #[doc(hidden)]
pub struct HeaderPredicate<S>(header::HeaderName, header::HeaderValue, PhantomData<S>); pub struct HeaderPredicate<S>(
header::HeaderName,
header::HeaderValue,
PhantomData<S>,
);
impl<S: 'static> Predicate<S> for HeaderPredicate<S> { impl<S: 'static> Predicate<S> for HeaderPredicate<S> {
fn check(&self, req: &mut HttpRequest<S>) -> bool { fn check(&self, req: &mut HttpRequest<S>) -> bool {
if let Some(val) = req.headers().get(&self.0) { if let Some(val) = req.headers().get(&self.0) {
return val == self.1 return val == self.1;
} }
false false
} }
@ -195,17 +199,24 @@ impl<S: 'static> Predicate<S> for HeaderPredicate<S> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::str::FromStr;
use http::{Uri, Version, Method};
use http::header::{self, HeaderMap}; use http::header::{self, HeaderMap};
use http::{Method, Uri, Version};
use std::str::FromStr;
#[test] #[test]
fn test_header() { fn test_header() {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::TRANSFER_ENCODING, headers.insert(
header::HeaderValue::from_static("chunked")); header::TRANSFER_ENCODING,
header::HeaderValue::from_static("chunked"),
);
let mut req = HttpRequest::new( let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
let pred = Header("transfer-encoding", "chunked"); let pred = Header("transfer-encoding", "chunked");
assert!(pred.check(&mut req)); assert!(pred.check(&mut req));
@ -220,11 +231,19 @@ mod tests {
#[test] #[test]
fn test_methods() { fn test_methods() {
let mut req = HttpRequest::new( let mut req = HttpRequest::new(
Method::GET, Uri::from_str("/").unwrap(), Method::GET,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
let mut req2 = HttpRequest::new( let mut req2 = HttpRequest::new(
Method::POST, Uri::from_str("/").unwrap(), Method::POST,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Get().check(&mut req)); assert!(Get().check(&mut req));
assert!(!Get().check(&mut req2)); assert!(!Get().check(&mut req2));
@ -232,44 +251,72 @@ mod tests {
assert!(!Post().check(&mut req)); assert!(!Post().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::PUT, Uri::from_str("/").unwrap(), Method::PUT,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Put().check(&mut r)); assert!(Put().check(&mut r));
assert!(!Put().check(&mut req)); assert!(!Put().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::DELETE, Uri::from_str("/").unwrap(), Method::DELETE,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Delete().check(&mut r)); assert!(Delete().check(&mut r));
assert!(!Delete().check(&mut req)); assert!(!Delete().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::HEAD, Uri::from_str("/").unwrap(), Method::HEAD,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Head().check(&mut r)); assert!(Head().check(&mut r));
assert!(!Head().check(&mut req)); assert!(!Head().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::OPTIONS, Uri::from_str("/").unwrap(), Method::OPTIONS,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Options().check(&mut r)); assert!(Options().check(&mut r));
assert!(!Options().check(&mut req)); assert!(!Options().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::CONNECT, Uri::from_str("/").unwrap(), Method::CONNECT,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Connect().check(&mut r)); assert!(Connect().check(&mut r));
assert!(!Connect().check(&mut req)); assert!(!Connect().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::PATCH, Uri::from_str("/").unwrap(), Method::PATCH,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Patch().check(&mut r)); assert!(Patch().check(&mut r));
assert!(!Patch().check(&mut req)); assert!(!Patch().check(&mut req));
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::TRACE, Uri::from_str("/").unwrap(), Method::TRACE,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Trace().check(&mut r)); assert!(Trace().check(&mut r));
assert!(!Trace().check(&mut req)); assert!(!Trace().check(&mut req));
} }
@ -277,8 +324,12 @@ mod tests {
#[test] #[test]
fn test_preds() { fn test_preds() {
let mut r = HttpRequest::new( let mut r = HttpRequest::new(
Method::TRACE, Uri::from_str("/").unwrap(), Method::TRACE,
Version::HTTP_11, HeaderMap::new(), None); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert!(Not(Get()).check(&mut r)); assert!(Not(Get()).check(&mut r));
assert!(!Not(Trace()).check(&mut r)); assert!(!Not(Trace()).check(&mut r));

View File

@ -1,15 +1,15 @@
use std::rc::Rc;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc;
use smallvec::SmallVec;
use http::{Method, StatusCode}; use http::{Method, StatusCode};
use smallvec::SmallVec;
use pred; use handler::{FromRequest, Handler, Reply, Responder};
use route::Route;
use handler::{Reply, Handler, Responder, FromRequest};
use middleware::Middleware;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::Middleware;
use pred;
use route::Route;
/// *Resource* is an entry in route table which corresponds to requested URL. /// *Resource* is an entry in route table which corresponds to requested URL.
/// ///
@ -18,8 +18,8 @@ use httpresponse::HttpResponse;
/// and list of predicates (objects that implement `Predicate` trait). /// and list of predicates (objects that implement `Predicate` trait).
/// Route uses builder-like pattern for configuration. /// Route uses builder-like pattern for configuration.
/// During request handling, resource object iterate through all routes /// During request handling, resource object iterate through all routes
/// and check all predicates for specific route, if request matches all predicates route /// and check all predicates for specific route, if request matches all
/// route considered matched and route handler get called. /// predicates route route considered matched and route handler get called.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -31,7 +31,7 @@ use httpresponse::HttpResponse;
/// "/", |r| r.method(http::Method::GET).f(|r| HttpResponse::Ok())) /// "/", |r| r.method(http::Method::GET).f(|r| HttpResponse::Ok()))
/// .finish(); /// .finish();
/// } /// }
pub struct ResourceHandler<S=()> { pub struct ResourceHandler<S = ()> {
name: String, name: String,
state: PhantomData<S>, state: PhantomData<S>,
routes: SmallVec<[Route<S>; 3]>, routes: SmallVec<[Route<S>; 3]>,
@ -44,18 +44,19 @@ impl<S> Default for ResourceHandler<S> {
name: String::new(), name: String::new(),
state: PhantomData, state: PhantomData,
routes: SmallVec::new(), routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()) } middlewares: Rc::new(Vec::new()),
}
} }
} }
impl<S> ResourceHandler<S> { impl<S> ResourceHandler<S> {
pub(crate) fn default_not_found() -> Self { pub(crate) fn default_not_found() -> Self {
ResourceHandler { ResourceHandler {
name: String::new(), name: String::new(),
state: PhantomData, state: PhantomData,
routes: SmallVec::new(), routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()) } middlewares: Rc::new(Vec::new()),
}
} }
/// Set resource name /// Set resource name
@ -69,9 +70,9 @@ impl<S> ResourceHandler<S> {
} }
impl<S: 'static> ResourceHandler<S> { impl<S: 'static> ResourceHandler<S> {
/// Register a new route and return mutable reference to *Route* object. /// Register a new route and return mutable reference to *Route* object.
/// *Route* is used for route configuration, i.e. adding predicates, setting up handler. /// *Route* is used for route configuration, i.e. adding predicates,
/// setting up handler.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_web; /// # extern crate actix_web;
@ -131,7 +132,10 @@ impl<S: 'static> ResourceHandler<S> {
/// ``` /// ```
pub fn method(&mut self, method: Method) -> &mut Route<S> { pub fn method(&mut self, method: Method) -> &mut Route<S> {
self.routes.push(Route::default()); self.routes.push(Route::default());
self.routes.last_mut().unwrap().filter(pred::Method(method)) self.routes
.last_mut()
.unwrap()
.filter(pred::Method(method))
} }
/// Register a new route and add handler object. /// Register a new route and add handler object.
@ -154,8 +158,9 @@ impl<S: 'static> ResourceHandler<S> {
/// Application::resource("/", |r| r.route().f(index) /// Application::resource("/", |r| r.route().f(index)
/// ``` /// ```
pub fn f<F, R>(&mut self, handler: F) pub fn f<F, R>(&mut self, handler: F)
where F: Fn(HttpRequest<S>) -> R + 'static, where
R: Responder + 'static, F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{ {
self.routes.push(Route::default()); self.routes.push(Route::default());
self.routes.last_mut().unwrap().f(handler) self.routes.last_mut().unwrap().f(handler)
@ -169,9 +174,10 @@ impl<S: 'static> ResourceHandler<S> {
/// Application::resource("/", |r| r.route().with(index) /// Application::resource("/", |r| r.route().with(index)
/// ``` /// ```
pub fn with<T, F, R>(&mut self, handler: F) pub fn with<T, F, R>(&mut self, handler: F)
where F: Fn(T) -> R + 'static, where
R: Responder + 'static, F: Fn(T) -> R + 'static,
T: FromRequest<S> + 'static, R: Responder + 'static,
T: FromRequest<S> + 'static,
{ {
self.routes.push(Route::default()); self.routes.push(Route::default());
self.routes.last_mut().unwrap().with(handler); self.routes.last_mut().unwrap().with(handler);
@ -182,13 +188,14 @@ impl<S: 'static> ResourceHandler<S> {
/// This is similar to `App's` middlewares, but /// This is similar to `App's` middlewares, but
/// middlewares get invoked on resource level. /// middlewares get invoked on resource level.
pub fn middleware<M: Middleware<S>>(&mut self, mw: M) { pub fn middleware<M: Middleware<S>>(&mut self, mw: M) {
Rc::get_mut(&mut self.middlewares).unwrap().push(Box::new(mw)); Rc::get_mut(&mut self.middlewares)
.unwrap()
.push(Box::new(mw));
} }
pub(crate) fn handle(&mut self, pub(crate) fn handle(
mut req: HttpRequest<S>, &mut self, mut req: HttpRequest<S>, default: Option<&mut ResourceHandler<S>>
default: Option<&mut ResourceHandler<S>>) -> Reply ) -> Reply {
{
for route in &mut self.routes { for route in &mut self.routes {
if route.check(&mut req) { if route.check(&mut req) {
return if self.middlewares.is_empty() { return if self.middlewares.is_empty() {

View File

@ -1,17 +1,18 @@
use futures::{Async, Future, Poll};
use std::marker::PhantomData;
use std::mem; use std::mem;
use std::rc::Rc; use std::rc::Rc;
use std::marker::PhantomData;
use futures::{Async, Future, Poll};
use error::Error; use error::Error;
use pred::Predicate; use handler::{AsyncHandler, FromRequest, Handler, Reply, ReplyItem, Responder,
RouteHandler, WrapHandler};
use http::StatusCode; use http::StatusCode;
use handler::{Reply, ReplyItem, Handler, FromRequest,
Responder, RouteHandler, AsyncHandler, WrapHandler};
use middleware::{Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use with::{With, With2, With3, ExtractorConfig}; use middleware::{Middleware, Response as MiddlewareResponse,
Started as MiddlewareStarted};
use pred::Predicate;
use with::{ExtractorConfig, With, With2, With3};
/// Resource route definition /// Resource route definition
/// ///
@ -23,7 +24,6 @@ pub struct Route<S> {
} }
impl<S: 'static> Default for Route<S> { impl<S: 'static> Default for Route<S> {
fn default() -> Route<S> { fn default() -> Route<S> {
Route { Route {
preds: Vec::new(), preds: Vec::new(),
@ -33,12 +33,11 @@ impl<S: 'static> Default for Route<S> {
} }
impl<S: 'static> Route<S> { impl<S: 'static> Route<S> {
#[inline] #[inline]
pub(crate) fn check(&self, req: &mut HttpRequest<S>) -> bool { pub(crate) fn check(&self, req: &mut HttpRequest<S>) -> bool {
for pred in &self.preds { for pred in &self.preds {
if !pred.check(req) { if !pred.check(req) {
return false return false;
} }
} }
true true
@ -50,9 +49,9 @@ impl<S: 'static> Route<S> {
} }
#[inline] #[inline]
pub(crate) fn compose(&mut self, pub(crate) fn compose(
req: HttpRequest<S>, &mut self, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>
mws: Rc<Vec<Box<Middleware<S>>>>) -> Reply { ) -> Reply {
Reply::async(Compose::new(req, mws, self.handler.clone())) Reply::async(Compose::new(req, mws, self.handler.clone()))
} }
@ -86,18 +85,20 @@ impl<S: 'static> Route<S> {
/// Set handler function. Usually call to this method is last call /// Set handler function. Usually call to this method is last call
/// during route configuration, so it does not return reference to self. /// during route configuration, so it does not return reference to self.
pub fn f<F, R>(&mut self, handler: F) pub fn f<F, R>(&mut self, handler: F)
where F: Fn(HttpRequest<S>) -> R + 'static, where
R: Responder + 'static, F: Fn(HttpRequest<S>) -> R + 'static,
R: Responder + 'static,
{ {
self.handler = InnerHandler::new(handler); self.handler = InnerHandler::new(handler);
} }
/// Set async handler function. /// Set async handler function.
pub fn a<H, R, F, E>(&mut self, handler: H) pub fn a<H, R, F, E>(&mut self, handler: H)
where H: Fn(HttpRequest<S>) -> F + 'static, where
F: Future<Item=R, Error=E> + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
R: Responder + 'static, F: Future<Item = R, Error = E> + 'static,
E: Into<Error> + 'static R: Responder + 'static,
E: Into<Error> + 'static,
{ {
self.handler = InnerHandler::async(handler); self.handler = InnerHandler::async(handler);
} }
@ -128,9 +129,10 @@ impl<S: 'static> Route<S> {
/// } /// }
/// ``` /// ```
pub fn with<T, F, R>(&mut self, handler: F) -> ExtractorConfig<S, T> pub fn with<T, F, R>(&mut self, handler: F) -> ExtractorConfig<S, T>
where F: Fn(T) -> R + 'static, where
R: Responder + 'static, F: Fn(T) -> R + 'static,
T: FromRequest<S> + 'static, R: Responder + 'static,
T: FromRequest<S> + 'static,
{ {
let cfg = ExtractorConfig::default(); let cfg = ExtractorConfig::default();
self.h(With::new(handler, Clone::clone(&cfg))); self.h(With::new(handler, Clone::clone(&cfg)));
@ -167,43 +169,58 @@ impl<S: 'static> Route<S> {
/// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor /// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor
/// } /// }
/// ``` /// ```
pub fn with2<T1, T2, F, R>(&mut self, handler: F) pub fn with2<T1, T2, F, R>(
-> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>) &mut self, handler: F
where F: Fn(T1, T2) -> R + 'static, ) -> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>)
R: Responder + 'static, where
T1: FromRequest<S> + 'static, F: Fn(T1, T2) -> R + 'static,
T2: FromRequest<S> + 'static, R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
{ {
let cfg1 = ExtractorConfig::default(); let cfg1 = ExtractorConfig::default();
let cfg2 = ExtractorConfig::default(); let cfg2 = ExtractorConfig::default();
self.h(With2::new(handler, Clone::clone(&cfg1), Clone::clone(&cfg2))); self.h(With2::new(
handler,
Clone::clone(&cfg1),
Clone::clone(&cfg2),
));
(cfg1, cfg2) (cfg1, cfg2)
} }
/// Set handler function, use request extractor for all paramters. /// Set handler function, use request extractor for all paramters.
pub fn with3<T1, T2, T3, F, R>(&mut self, handler: F) pub fn with3<T1, T2, T3, F, R>(
-> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>, ExtractorConfig<S, T3>) &mut self, handler: F
where F: Fn(T1, T2, T3) -> R + 'static, ) -> (
R: Responder + 'static, ExtractorConfig<S, T1>,
T1: FromRequest<S> + 'static, ExtractorConfig<S, T2>,
T2: FromRequest<S> + 'static, ExtractorConfig<S, T3>,
T3: FromRequest<S> + 'static, )
where
F: Fn(T1, T2, T3) -> R + 'static,
R: Responder + 'static,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static,
{ {
let cfg1 = ExtractorConfig::default(); let cfg1 = ExtractorConfig::default();
let cfg2 = ExtractorConfig::default(); let cfg2 = ExtractorConfig::default();
let cfg3 = ExtractorConfig::default(); let cfg3 = ExtractorConfig::default();
self.h(With3::new( self.h(With3::new(
handler, Clone::clone(&cfg1), Clone::clone(&cfg2), Clone::clone(&cfg3))); handler,
Clone::clone(&cfg1),
Clone::clone(&cfg2),
Clone::clone(&cfg3),
));
(cfg1, cfg2, cfg3) (cfg1, cfg2, cfg3)
} }
} }
/// `RouteHandler` wrapper. This struct is required because it needs to be shared /// `RouteHandler` wrapper. This struct is required because it needs to be
/// for resource level middlewares. /// shared for resource level middlewares.
struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>); struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>);
impl<S: 'static> InnerHandler<S> { impl<S: 'static> InnerHandler<S> {
#[inline] #[inline]
fn new<H: Handler<S>>(h: H) -> Self { fn new<H: Handler<S>>(h: H) -> Self {
InnerHandler(Rc::new(Box::new(WrapHandler::new(h)))) InnerHandler(Rc::new(Box::new(WrapHandler::new(h))))
@ -211,10 +228,11 @@ impl<S: 'static> InnerHandler<S> {
#[inline] #[inline]
fn async<H, R, F, E>(h: H) -> Self fn async<H, R, F, E>(h: H) -> Self
where H: Fn(HttpRequest<S>) -> F + 'static, where
F: Future<Item=R, Error=E> + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
R: Responder + 'static, F: Future<Item = R, Error = E> + 'static,
E: Into<Error> + 'static R: Responder + 'static,
E: Into<Error> + 'static,
{ {
InnerHandler(Rc::new(Box::new(AsyncHandler::new(h)))) InnerHandler(Rc::new(Box::new(AsyncHandler::new(h))))
} }
@ -237,7 +255,6 @@ impl<S> Clone for InnerHandler<S> {
} }
} }
/// Compose resource level middlewares with route handler. /// Compose resource level middlewares with route handler.
struct Compose<S: 'static> { struct Compose<S: 'static> {
info: ComposeInfo<S>, info: ComposeInfo<S>,
@ -270,14 +287,18 @@ impl<S: 'static> ComposeState<S> {
} }
impl<S: 'static> Compose<S> { impl<S: 'static> Compose<S> {
fn new(req: HttpRequest<S>, fn new(
mws: Rc<Vec<Box<Middleware<S>>>>, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, handler: InnerHandler<S>
handler: InnerHandler<S>) -> Self ) -> Self {
{ let mut info = ComposeInfo {
let mut info = ComposeInfo { count: 0, req, mws, handler }; count: 0,
req,
mws,
handler,
};
let state = StartMiddlewares::init(&mut info); let state = StartMiddlewares::init(&mut info);
Compose {state, info} Compose { state, info }
} }
} }
@ -289,12 +310,12 @@ impl<S> Future for Compose<S> {
loop { loop {
if let ComposeState::Response(ref mut resp) = self.state { if let ComposeState::Response(ref mut resp) = self.state {
let resp = resp.resp.take().unwrap(); let resp = resp.resp.take().unwrap();
return Ok(Async::Ready(resp)) return Ok(Async::Ready(resp));
} }
if let Some(state) = self.state.poll(&mut self.info) { if let Some(state) = self.state.poll(&mut self.info) {
self.state = state; self.state = state;
} else { } else {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} }
} }
@ -306,51 +327,47 @@ struct StartMiddlewares<S> {
_s: PhantomData<S>, _s: PhantomData<S>,
} }
type Fut = Box<Future<Item=Option<HttpResponse>, Error=Error>>; type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
impl<S: 'static> StartMiddlewares<S> { impl<S: 'static> StartMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> {
let len = info.mws.len(); let len = info.mws.len();
loop { loop {
if info.count == len { if info.count == len {
let reply = info.handler.handle(info.req.clone()); let reply = info.handler.handle(info.req.clone());
return WaitingResponse::init(info, reply) return WaitingResponse::init(info, reply);
} else { } else {
match info.mws[info.count].start(&mut info.req) { match info.mws[info.count].start(&mut info.req) {
Ok(MiddlewareStarted::Done) => Ok(MiddlewareStarted::Done) => info.count += 1,
info.count += 1, Ok(MiddlewareStarted::Response(resp)) => {
Ok(MiddlewareStarted::Response(resp)) => return RunMiddlewares::init(info, resp)
return RunMiddlewares::init(info, resp), }
Ok(MiddlewareStarted::Future(mut fut)) => Ok(MiddlewareStarted::Future(mut fut)) => match fut.poll() {
match fut.poll() { Ok(Async::NotReady) => {
Ok(Async::NotReady) => return ComposeState::Starting(StartMiddlewares {
return ComposeState::Starting(StartMiddlewares { fut: Some(fut),
fut: Some(fut), _s: PhantomData,
_s: PhantomData}), })
Ok(Async::Ready(resp)) => { }
if let Some(resp) = resp { Ok(Async::Ready(resp)) => {
return RunMiddlewares::init(info, resp); if let Some(resp) = resp {
} return RunMiddlewares::init(info, resp);
info.count += 1;
} }
Err(err) => info.count += 1;
return Response::init(err.into()), }
}, Err(err) => return Response::init(err.into()),
Err(err) => },
return Response::init(err.into()), Err(err) => return Response::init(err.into()),
} }
} }
} }
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
{
let len = info.mws.len(); let len = info.mws.len();
'outer: loop { 'outer: loop {
match self.fut.as_mut().unwrap().poll() { match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => Ok(Async::NotReady) => return None,
return None,
Ok(Async::Ready(resp)) => { Ok(Async::Ready(resp)) => {
info.count += 1; info.count += 1;
if let Some(resp) = resp { if let Some(resp) = resp {
@ -362,23 +379,20 @@ impl<S: 'static> StartMiddlewares<S> {
} else { } else {
loop { loop {
match info.mws[info.count].start(&mut info.req) { match info.mws[info.count].start(&mut info.req) {
Ok(MiddlewareStarted::Done) => Ok(MiddlewareStarted::Done) => info.count += 1,
info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => { Ok(MiddlewareStarted::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
}, }
Ok(MiddlewareStarted::Future(fut)) => { Ok(MiddlewareStarted::Future(fut)) => {
self.fut = Some(fut); self.fut = Some(fut);
continue 'outer continue 'outer;
}, }
Err(err) => Err(err) => return Some(Response::init(err.into())),
return Some(Response::init(err.into()))
} }
} }
} }
} }
Err(err) => Err(err) => return Some(Response::init(err.into())),
return Some(Response::init(err.into()))
} }
} }
} }
@ -386,44 +400,39 @@ impl<S: 'static> StartMiddlewares<S> {
// waiting for response // waiting for response
struct WaitingResponse<S> { struct WaitingResponse<S> {
fut: Box<Future<Item=HttpResponse, Error=Error>>, fut: Box<Future<Item = HttpResponse, Error = Error>>,
_s: PhantomData<S>, _s: PhantomData<S>,
} }
impl<S: 'static> WaitingResponse<S> { impl<S: 'static> WaitingResponse<S> {
#[inline] #[inline]
fn init(info: &mut ComposeInfo<S>, reply: Reply) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>, reply: Reply) -> ComposeState<S> {
match reply.into() { match reply.into() {
ReplyItem::Message(resp) => ReplyItem::Message(resp) => RunMiddlewares::init(info, resp),
RunMiddlewares::init(info, resp), ReplyItem::Future(fut) => ComposeState::Handler(WaitingResponse {
ReplyItem::Future(fut) => fut,
ComposeState::Handler( _s: PhantomData,
WaitingResponse { fut, _s: PhantomData }), }),
} }
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
match self.fut.poll() { match self.fut.poll() {
Ok(Async::NotReady) => None, Ok(Async::NotReady) => None,
Ok(Async::Ready(response)) => Ok(Async::Ready(response)) => Some(RunMiddlewares::init(info, response)),
Some(RunMiddlewares::init(info, response)), Err(err) => Some(Response::init(err.into())),
Err(err) =>
Some(Response::init(err.into())),
} }
} }
} }
/// Middlewares response executor /// Middlewares response executor
struct RunMiddlewares<S> { struct RunMiddlewares<S> {
curr: usize, curr: usize,
fut: Option<Box<Future<Item=HttpResponse, Error=Error>>>, fut: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
_s: PhantomData<S>, _s: PhantomData<S>,
} }
impl<S: 'static> RunMiddlewares<S> { impl<S: 'static> RunMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> {
let mut curr = 0; let mut curr = 0;
let len = info.mws.len(); let len = info.mws.len();
@ -432,40 +441,39 @@ impl<S: 'static> RunMiddlewares<S> {
resp = match info.mws[curr].response(&mut info.req, resp) { resp = match info.mws[curr].response(&mut info.req, resp) {
Err(err) => { Err(err) => {
info.count = curr + 1; info.count = curr + 1;
return Response::init(err.into()) return Response::init(err.into());
}, }
Ok(MiddlewareResponse::Done(r)) => { Ok(MiddlewareResponse::Done(r)) => {
curr += 1; curr += 1;
if curr == len { if curr == len {
return Response::init(r) return Response::init(r);
} else { } else {
r r
} }
}, }
Ok(MiddlewareResponse::Future(fut)) => { Ok(MiddlewareResponse::Future(fut)) => {
return ComposeState::RunMiddlewares( return ComposeState::RunMiddlewares(RunMiddlewares {
RunMiddlewares { curr, fut: Some(fut), _s: PhantomData }) curr,
}, fut: Some(fut),
_s: PhantomData,
})
}
}; };
} }
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
{
let len = info.mws.len(); let len = info.mws.len();
loop { loop {
// poll latest fut // poll latest fut
let mut resp = match self.fut.as_mut().unwrap().poll() { let mut resp = match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => { Ok(Async::NotReady) => return None,
return None
}
Ok(Async::Ready(resp)) => { Ok(Async::Ready(resp)) => {
self.curr += 1; self.curr += 1;
resp resp
} }
Err(err) => Err(err) => return Some(Response::init(err.into())),
return Some(Response::init(err.into())),
}; };
loop { loop {
@ -473,16 +481,15 @@ impl<S: 'static> RunMiddlewares<S> {
return Some(Response::init(resp)); return Some(Response::init(resp));
} else { } else {
match info.mws[self.curr].response(&mut info.req, resp) { match info.mws[self.curr].response(&mut info.req, resp) {
Err(err) => Err(err) => return Some(Response::init(err.into())),
return Some(Response::init(err.into())),
Ok(MiddlewareResponse::Done(r)) => { Ok(MiddlewareResponse::Done(r)) => {
self.curr += 1; self.curr += 1;
resp = r resp = r
}, }
Ok(MiddlewareResponse::Future(fut)) => { Ok(MiddlewareResponse::Future(fut)) => {
self.fut = Some(fut); self.fut = Some(fut);
break break;
}, }
} }
} }
} }
@ -496,9 +503,10 @@ struct Response<S> {
} }
impl<S: 'static> Response<S> { impl<S: 'static> Response<S> {
fn init(resp: HttpResponse) -> ComposeState<S> { fn init(resp: HttpResponse) -> ComposeState<S> {
ComposeState::Response( ComposeState::Response(Response {
Response{resp: Some(resp), _s: PhantomData}) resp: Some(resp),
_s: PhantomData,
})
} }
} }

View File

@ -1,15 +1,15 @@
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem; use std::mem;
use std::rc::Rc; use std::rc::Rc;
use std::hash::{Hash, Hasher};
use std::collections::HashMap;
use regex::{Regex, escape};
use percent_encoding::percent_decode; use percent_encoding::percent_decode;
use regex::{escape, Regex};
use param::Params;
use error::UrlGenerationError; use error::UrlGenerationError;
use resource::ResourceHandler;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use param::Params;
use resource::ResourceHandler;
use server::ServerSettings; use server::ServerSettings;
/// Interface for application router. /// Interface for application router.
@ -25,11 +25,10 @@ struct Inner {
impl Router { impl Router {
/// Create new router /// Create new router
pub fn new<S>(prefix: &str, pub fn new<S>(
settings: ServerSettings, prefix: &str, settings: ServerSettings,
map: Vec<(Resource, Option<ResourceHandler<S>>)>) map: Vec<(Resource, Option<ResourceHandler<S>>)>,
-> (Router, Vec<ResourceHandler<S>>) ) -> (Router, Vec<ResourceHandler<S>>) {
{
let prefix = prefix.trim().trim_right_matches('/').to_owned(); let prefix = prefix.trim().trim_right_matches('/').to_owned();
let mut named = HashMap::new(); let mut named = HashMap::new();
let mut patterns = Vec::new(); let mut patterns = Vec::new();
@ -48,8 +47,16 @@ impl Router {
} }
let prefix_len = prefix.len(); let prefix_len = prefix.len();
(Router(Rc::new( (
Inner{ prefix, prefix_len, named, patterns, srv: settings })), resources) Router(Rc::new(Inner {
prefix,
prefix_len,
named,
patterns,
srv: settings,
})),
resources,
)
} }
/// Router prefix /// Router prefix
@ -71,16 +78,18 @@ impl Router {
/// Query for matched resource /// Query for matched resource
pub fn recognize<S>(&self, req: &mut HttpRequest<S>) -> Option<usize> { pub fn recognize<S>(&self, req: &mut HttpRequest<S>) -> Option<usize> {
if self.0.prefix_len > req.path().len() { if self.0.prefix_len > req.path().len() {
return None return None;
} }
let path: &str = unsafe{mem::transmute(&req.path()[self.0.prefix_len..])}; let path: &str = unsafe { mem::transmute(&req.path()[self.0.prefix_len..]) };
let route_path = if path.is_empty() { "/" } else { path }; let route_path = if path.is_empty() { "/" } else { path };
let p = percent_decode(route_path.as_bytes()).decode_utf8().unwrap(); let p = percent_decode(route_path.as_bytes())
.decode_utf8()
.unwrap();
for (idx, pattern) in self.0.patterns.iter().enumerate() { for (idx, pattern) in self.0.patterns.iter().enumerate() {
if pattern.match_with_params(p.as_ref(), req.match_info_mut()) { if pattern.match_with_params(p.as_ref(), req.match_info_mut()) {
req.set_resource(idx); req.set_resource(idx);
return Some(idx) return Some(idx);
} }
} }
None None
@ -97,7 +106,7 @@ impl Router {
for pattern in &self.0.patterns { for pattern in &self.0.patterns {
if pattern.is_match(path) { if pattern.is_match(path) {
return true return true;
} }
} }
false false
@ -105,12 +114,14 @@ impl Router {
/// Build named resource path. /// Build named resource path.
/// ///
/// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method.url_for) /// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method.
/// for detailed information. /// url_for) for detailed information.
pub fn resource_path<U, I>(&self, name: &str, elements: U) pub fn resource_path<U, I>(
-> Result<String, UrlGenerationError> &self, name: &str, elements: U
where U: IntoIterator<Item=I>, ) -> Result<String, UrlGenerationError>
I: AsRef<str>, where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{ {
if let Some(pattern) = self.0.named.get(name) { if let Some(pattern) = self.0.named.get(name) {
pattern.0.resource_path(self, elements) pattern.0.resource_path(self, elements)
@ -196,7 +207,7 @@ impl Resource {
let tp = if is_dynamic { let tp = if is_dynamic {
let re = match Regex::new(&pattern) { let re = match Regex::new(&pattern) {
Ok(re) => re, Ok(re) => re,
Err(err) => panic!("Wrong path pattern: \"{}\" {}", path, err) Err(err) => panic!("Wrong path pattern: \"{}\" {}", path, err),
}; };
let names = re.capture_names() let names = re.capture_names()
.filter_map(|name| name.map(|name| name.to_owned())) .filter_map(|name| name.map(|name| name.to_owned()))
@ -237,9 +248,9 @@ impl Resource {
} }
} }
pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) pub fn match_with_params<'a>(
-> bool &'a self, path: &'a str, params: &'a mut Params<'a>
{ ) -> bool {
match self.tp { match self.tp {
PatternType::Static(ref s) => s == path, PatternType::Static(ref s) => s == path,
PatternType::Dynamic(ref re, ref names) => { PatternType::Dynamic(ref re, ref names) => {
@ -248,7 +259,7 @@ impl Resource {
for capture in captures.iter() { for capture in captures.iter() {
if let Some(ref m) = capture { if let Some(ref m) = capture {
if idx != 0 { if idx != 0 {
params.add(names[idx-1].as_str(), m.as_str()); params.add(names[idx - 1].as_str(), m.as_str());
} }
idx += 1; idx += 1;
} }
@ -262,10 +273,12 @@ impl Resource {
} }
/// Build reousrce path. /// Build reousrce path.
pub fn resource_path<U, I>(&self, router: &Router, elements: U) pub fn resource_path<U, I>(
-> Result<String, UrlGenerationError> &self, router: &Router, elements: U
where U: IntoIterator<Item=I>, ) -> Result<String, UrlGenerationError>
I: AsRef<str>, where
U: IntoIterator<Item = I>,
I: AsRef<str>,
{ {
let mut iter = elements.into_iter(); let mut iter = elements.into_iter();
let mut path = if self.rtp != ResourceType::External { let mut path = if self.rtp != ResourceType::External {
@ -280,7 +293,7 @@ impl Resource {
if let Some(val) = iter.next() { if let Some(val) = iter.next() {
path.push_str(val.as_ref()) path.push_str(val.as_ref())
} else { } else {
return Err(UrlGenerationError::NotEnoughElements) return Err(UrlGenerationError::NotEnoughElements);
} }
} }
} }
@ -374,20 +387,35 @@ mod tests {
#[test] #[test]
fn test_recognizer() { fn test_recognizer() {
let routes = vec![ let routes = vec![
(Resource::new("", "/name"), (
Some(ResourceHandler::default())), Resource::new("", "/name"),
(Resource::new("", "/name/{val}"), Some(ResourceHandler::default()),
Some(ResourceHandler::default())), ),
(Resource::new("", "/name/{val}/index.html"), (
Some(ResourceHandler::default())), Resource::new("", "/name/{val}"),
(Resource::new("", "/file/{file}.{ext}"), Some(ResourceHandler::default()),
Some(ResourceHandler::default())), ),
(Resource::new("", "/v{val}/{val2}/index.html"), (
Some(ResourceHandler::default())), Resource::new("", "/name/{val}/index.html"),
(Resource::new("", "/v/{tail:.*}"), Some(ResourceHandler::default()),
Some(ResourceHandler::default())), ),
(Resource::new("", "{test}/index.html"), (
Some(ResourceHandler::default()))]; Resource::new("", "/file/{file}.{ext}"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/v{val}/{val2}/index.html"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/v/{tail:.*}"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "{test}/index.html"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes); let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish(); let mut req = TestRequest::with_uri("/name").finish();
@ -415,7 +443,10 @@ mod tests {
let mut req = TestRequest::with_uri("/v/blah-blah/index.html").finish(); let mut req = TestRequest::with_uri("/v/blah-blah/index.html").finish();
assert_eq!(rec.recognize(&mut req), Some(5)); assert_eq!(rec.recognize(&mut req), Some(5));
assert_eq!(req.match_info().get("tail").unwrap(), "blah-blah/index.html"); assert_eq!(
req.match_info().get("tail").unwrap(),
"blah-blah/index.html"
);
let mut req = TestRequest::with_uri("/bbb/index.html").finish(); let mut req = TestRequest::with_uri("/bbb/index.html").finish();
assert_eq!(rec.recognize(&mut req), Some(6)); assert_eq!(rec.recognize(&mut req), Some(6));
@ -425,8 +456,15 @@ mod tests {
#[test] #[test]
fn test_recognizer_2() { fn test_recognizer_2() {
let routes = vec![ let routes = vec![
(Resource::new("", "/index.json"), Some(ResourceHandler::default())), (
(Resource::new("", "/{source}.json"), Some(ResourceHandler::default()))]; Resource::new("", "/index.json"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/{source}.json"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes); let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/index.json").finish(); let mut req = TestRequest::with_uri("/index.json").finish();
@ -439,8 +477,15 @@ mod tests {
#[test] #[test]
fn test_recognizer_with_prefix() { fn test_recognizer_with_prefix() {
let routes = vec![ let routes = vec![
(Resource::new("", "/name"), Some(ResourceHandler::default())), (
(Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))]; Resource::new("", "/name"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("/test", ServerSettings::default(), routes); let (rec, _) = Router::new::<()>("/test", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish(); let mut req = TestRequest::with_uri("/name").finish();
@ -456,8 +501,15 @@ mod tests {
// same patterns // same patterns
let routes = vec![ let routes = vec![
(Resource::new("", "/name"), Some(ResourceHandler::default())), (
(Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))]; Resource::new("", "/name"),
Some(ResourceHandler::default()),
),
(
Resource::new("", "/name/{val}"),
Some(ResourceHandler::default()),
),
];
let (rec, _) = Router::new::<()>("/test2", ServerSettings::default(), routes); let (rec, _) = Router::new::<()>("/test2", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/name").finish(); let mut req = TestRequest::with_uri("/name").finish();
@ -525,18 +577,25 @@ mod tests {
#[test] #[test]
fn test_request_resource() { fn test_request_resource() {
let routes = vec![ let routes = vec![
(Resource::new("r1", "/index.json"), Some(ResourceHandler::default())), (
(Resource::new("r2", "/test.json"), Some(ResourceHandler::default()))]; Resource::new("r1", "/index.json"),
Some(ResourceHandler::default()),
),
(
Resource::new("r2", "/test.json"),
Some(ResourceHandler::default()),
),
];
let (router, _) = Router::new::<()>("", ServerSettings::default(), routes); let (router, _) = Router::new::<()>("", ServerSettings::default(), routes);
let mut req = TestRequest::with_uri("/index.json") let mut req =
.finish_with_router(router.clone()); TestRequest::with_uri("/index.json").finish_with_router(router.clone());
assert_eq!(router.recognize(&mut req), Some(0)); assert_eq!(router.recognize(&mut req), Some(0));
let resource = req.resource(); let resource = req.resource();
assert_eq!(resource.name(), "r1"); assert_eq!(resource.name(), "r1");
let mut req = TestRequest::with_uri("/test.json") let mut req =
.finish_with_router(router.clone()); TestRequest::with_uri("/test.json").finish_with_router(router.clone());
assert_eq!(router.recognize(&mut req), Some(1)); assert_eq!(router.recognize(&mut req), Some(1));
let resource = req.resource(); let resource = req.resource();
assert_eq!(resource.name(), "r2"); assert_eq!(resource.name(), "r2");

View File

@ -1,17 +1,16 @@
use std::{ptr, mem, time, io}; use std::net::{Shutdown, SocketAddr};
use std::rc::Rc; use std::rc::Rc;
use std::net::{SocketAddr, Shutdown}; use std::{io, mem, ptr, time};
use bytes::{Bytes, BytesMut, Buf, BufMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{Future, Poll, Async}; use futures::{Async, Future, Poll};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use super::{h1, h2, utils, HttpHandler, IoStream};
use super::settings::WorkerSettings; use super::settings::WorkerSettings;
use super::{utils, HttpHandler, IoStream, h1, h2};
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
enum HttpProtocol<T: IoStream, H: 'static> { enum HttpProtocol<T: IoStream, H: 'static> {
H1(h1::Http1<T, H>), H1(h1::Http1<T, H>),
H2(h2::Http2<T, H>), H2(h2::Http2<T, H>),
@ -24,27 +23,47 @@ enum ProtocolKind {
} }
#[doc(hidden)] #[doc(hidden)]
pub struct HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static { pub struct HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{
proto: Option<HttpProtocol<T, H>>, proto: Option<HttpProtocol<T, H>>,
node: Option<Node<HttpChannel<T, H>>>, node: Option<Node<HttpChannel<T, H>>>,
} }
impl<T, H> HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static impl<T, H> HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{ {
pub(crate) fn new(settings: Rc<WorkerSettings<H>>, pub(crate) fn new(
mut io: T, peer: Option<SocketAddr>, http2: bool) -> HttpChannel<T, H> settings: Rc<WorkerSettings<H>>, mut io: T, peer: Option<SocketAddr>,
{ http2: bool,
) -> HttpChannel<T, H> {
settings.add_channel(); settings.add_channel();
let _ = io.set_nodelay(true); let _ = io.set_nodelay(true);
if http2 { if http2 {
HttpChannel { HttpChannel {
node: None, proto: Some(HttpProtocol::H2( node: None,
h2::Http2::new(settings, io, peer, Bytes::new()))) } proto: Some(HttpProtocol::H2(h2::Http2::new(
settings,
io,
peer,
Bytes::new(),
))),
}
} else { } else {
HttpChannel { HttpChannel {
node: None, proto: Some(HttpProtocol::Unknown( node: None,
settings, peer, io, BytesMut::with_capacity(4096))) } proto: Some(HttpProtocol::Unknown(
settings,
peer,
io,
BytesMut::with_capacity(4096),
)),
}
} }
} }
@ -55,15 +74,16 @@ impl<T, H> HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static
let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0)));
let _ = IoStream::shutdown(io, Shutdown::Both); let _ = IoStream::shutdown(io, Shutdown::Both);
} }
Some(HttpProtocol::H2(ref mut h2)) => { Some(HttpProtocol::H2(ref mut h2)) => h2.shutdown(),
h2.shutdown()
}
_ => (), _ => (),
} }
} }
} }
impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'static impl<T, H> Future for HttpChannel<T, H>
where
T: IoStream,
H: HttpHandler + 'static,
{ {
type Item = (); type Item = ();
type Error = (); type Error = ();
@ -73,12 +93,15 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
let el = self as *mut _; let el = self as *mut _;
self.node = Some(Node::new(el)); self.node = Some(Node::new(el));
let _ = match self.proto { let _ = match self.proto {
Some(HttpProtocol::H1(ref mut h1)) => Some(HttpProtocol::H1(ref mut h1)) => self.node
self.node.as_ref().map(|n| h1.settings().head().insert(n)), .as_ref()
Some(HttpProtocol::H2(ref mut h2)) => .map(|n| h1.settings().head().insert(n)),
self.node.as_ref().map(|n| h2.settings().head().insert(n)), Some(HttpProtocol::H2(ref mut h2)) => self.node
Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => .as_ref()
self.node.as_ref().map(|n| settings.head().insert(n)), .map(|n| h2.settings().head().insert(n)),
Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => {
self.node.as_ref().map(|n| settings.head().insert(n))
}
None => unreachable!(), None => unreachable!(),
}; };
} }
@ -90,30 +113,35 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
Ok(Async::Ready(())) | Err(_) => { Ok(Async::Ready(())) | Err(_) => {
h1.settings().remove_channel(); h1.settings().remove_channel();
self.node.as_mut().map(|n| n.remove()); self.node.as_mut().map(|n| n.remove());
}, }
_ => (), _ => (),
} }
return result return result;
}, }
Some(HttpProtocol::H2(ref mut h2)) => { Some(HttpProtocol::H2(ref mut h2)) => {
let result = h2.poll(); let result = h2.poll();
match result { match result {
Ok(Async::Ready(())) | Err(_) => { Ok(Async::Ready(())) | Err(_) => {
h2.settings().remove_channel(); h2.settings().remove_channel();
self.node.as_mut().map(|n| n.remove()); self.node.as_mut().map(|n| n.remove());
}, }
_ => (), _ => (),
} }
return result return result;
}, }
Some(HttpProtocol::Unknown(ref mut settings, _, ref mut io, ref mut buf)) => { Some(HttpProtocol::Unknown(
ref mut settings,
_,
ref mut io,
ref mut buf,
)) => {
match utils::read_from_io(io, buf) { match utils::read_from_io(io, buf) {
Ok(Async::Ready(0)) | Err(_) => { Ok(Async::Ready(0)) | Err(_) => {
debug!("Ignored premature client disconnection"); debug!("Ignored premature client disconnection");
settings.remove_channel(); settings.remove_channel();
self.node.as_mut().map(|n| n.remove()); self.node.as_mut().map(|n| n.remove());
return Err(()) return Err(());
}, }
_ => (), _ => (),
} }
@ -126,7 +154,7 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
} else { } else {
return Ok(Async::NotReady); return Ok(Async::NotReady);
} }
}, }
None => unreachable!(), None => unreachable!(),
}; };
@ -134,30 +162,36 @@ impl<T, H> Future for HttpChannel<T, H> where T: IoStream, H: HttpHandler + 'sta
if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() {
match kind { match kind {
ProtocolKind::Http1 => { ProtocolKind::Http1 => {
self.proto = Some( self.proto = Some(HttpProtocol::H1(h1::Http1::new(
HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); settings,
return self.poll() io,
}, addr,
buf,
)));
return self.poll();
}
ProtocolKind::Http2 => { ProtocolKind::Http2 => {
self.proto = Some( self.proto = Some(HttpProtocol::H2(h2::Http2::new(
HttpProtocol::H2(h2::Http2::new(settings, io, addr, buf.freeze()))); settings,
return self.poll() io,
}, addr,
buf.freeze(),
)));
return self.poll();
}
} }
} }
unreachable!() unreachable!()
} }
} }
pub(crate) struct Node<T> pub(crate) struct Node<T> {
{
next: Option<*mut Node<()>>, next: Option<*mut Node<()>>,
prev: Option<*mut Node<()>>, prev: Option<*mut Node<()>>,
element: *mut T, element: *mut T,
} }
impl<T> Node<T> impl<T> Node<T> {
{
fn new(el: *mut T) -> Self { fn new(el: *mut T) -> Self {
Node { Node {
next: None, next: None,
@ -194,9 +228,7 @@ impl<T> Node<T>
} }
} }
impl Node<()> { impl Node<()> {
pub(crate) fn head() -> Self { pub(crate) fn head() -> Self {
Node { Node {
next: None, next: None,
@ -205,7 +237,11 @@ impl Node<()> {
} }
} }
pub(crate) fn traverse<T, H>(&self) where T: IoStream, H: HttpHandler + 'static { pub(crate) fn traverse<T, H>(&self)
where
T: IoStream,
H: HttpHandler + 'static,
{
let mut next = self.next.as_ref(); let mut next = self.next.as_ref();
loop { loop {
if let Some(n) = next { if let Some(n) = next {
@ -214,30 +250,39 @@ impl Node<()> {
next = n.next.as_ref(); next = n.next.as_ref();
if !n.element.is_null() { if !n.element.is_null() {
let ch: &mut HttpChannel<T, H> = mem::transmute( let ch: &mut HttpChannel<T, H> =
&mut *(n.element as *mut _)); mem::transmute(&mut *(n.element as *mut _));
ch.shutdown(); ch.shutdown();
} }
} }
} else { } else {
return return;
} }
} }
} }
} }
/// Wrapper for `AsyncRead + AsyncWrite` types /// Wrapper for `AsyncRead + AsyncWrite` types
pub(crate) struct WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { pub(crate) struct WrapperStream<T>
io: T, where
T: AsyncRead + AsyncWrite + 'static,
{
io: T,
} }
impl<T> WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
pub fn new(io: T) -> Self { pub fn new(io: T) -> Self {
WrapperStream{ io } WrapperStream { io }
} }
} }
impl<T> IoStream for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> IoStream for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline] #[inline]
fn shutdown(&mut self, _: Shutdown) -> io::Result<()> { fn shutdown(&mut self, _: Shutdown) -> io::Result<()> {
Ok(()) Ok(())
@ -252,14 +297,20 @@ impl<T> IoStream for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static
} }
} }
impl<T> io::Read for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> io::Read for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline] #[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.io.read(buf) self.io.read(buf)
} }
} }
impl<T> io::Write for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> io::Write for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline] #[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.io.write(buf) self.io.write(buf)
@ -270,14 +321,20 @@ impl<T> io::Write for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static
} }
} }
impl<T> AsyncRead for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> AsyncRead for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline] #[inline]
fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> { fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.io.read_buf(buf) self.io.read_buf(buf)
} }
} }
impl<T> AsyncWrite for WrapperStream<T> where T: AsyncRead + AsyncWrite + 'static { impl<T> AsyncWrite for WrapperStream<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
#[inline] #[inline]
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
self.io.shutdown() self.io.shutdown()

View File

@ -1,25 +1,24 @@
use std::{io, cmp, mem};
use std::io::{Read, Write};
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::str::FromStr; use std::str::FromStr;
use std::{cmp, io, mem};
use bytes::{Bytes, BytesMut, BufMut}; #[cfg(feature = "brotli")]
use http::{Version, Method, HttpTryFrom}; use brotli2::write::{BrotliDecoder, BrotliEncoder};
use http::header::{HeaderMap, HeaderValue, use bytes::{BufMut, Bytes, BytesMut};
ACCEPT_ENCODING, CONNECTION,
CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use flate2::Compression; use flate2::Compression;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use flate2::write::{GzEncoder, DeflateDecoder, DeflateEncoder}; use flate2::write::{DeflateDecoder, DeflateEncoder, GzEncoder};
#[cfg(feature="brotli")] use http::header::{HeaderMap, HeaderValue, ACCEPT_ENCODING, CONNECTION,
use brotli2::write::{BrotliDecoder, BrotliEncoder}; CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{HttpTryFrom, Method, Version};
use header::ContentEncoding; use body::{Binary, Body};
use body::{Body, Binary};
use error::PayloadError; use error::PayloadError;
use header::ContentEncoding;
use httprequest::HttpInnerMessage; use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use payload::{PayloadSender, PayloadWriter, PayloadStatus}; use payload::{PayloadSender, PayloadStatus, PayloadWriter};
use super::shared::SharedBytes; use super::shared::SharedBytes;
@ -29,7 +28,6 @@ pub(crate) enum PayloadType {
} }
impl PayloadType { impl PayloadType {
pub fn new(headers: &HeaderMap, sender: PayloadSender) -> PayloadType { pub fn new(headers: &HeaderMap, sender: PayloadSender) -> PayloadType {
// check content-encoding // check content-encoding
let enc = if let Some(enc) = headers.get(CONTENT_ENCODING) { let enc = if let Some(enc) = headers.get(CONTENT_ENCODING) {
@ -43,8 +41,9 @@ impl PayloadType {
}; };
match enc { match enc {
ContentEncoding::Auto | ContentEncoding::Identity => ContentEncoding::Auto | ContentEncoding::Identity => {
PayloadType::Sender(sender), PayloadType::Sender(sender)
}
_ => PayloadType::Encoding(Box::new(EncodedPayload::new(sender, enc))), _ => PayloadType::Encoding(Box::new(EncodedPayload::new(sender, enc))),
} }
} }
@ -84,7 +83,6 @@ impl PayloadWriter for PayloadType {
} }
} }
/// Payload wrapper with content decompression support /// Payload wrapper with content decompression support
pub(crate) struct EncodedPayload { pub(crate) struct EncodedPayload {
inner: PayloadSender, inner: PayloadSender,
@ -94,12 +92,15 @@ pub(crate) struct EncodedPayload {
impl EncodedPayload { impl EncodedPayload {
pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload { pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload {
EncodedPayload{ inner, error: false, payload: PayloadStream::new(enc) } EncodedPayload {
inner,
error: false,
payload: PayloadStream::new(enc),
}
} }
} }
impl PayloadWriter for EncodedPayload { impl PayloadWriter for EncodedPayload {
fn set_error(&mut self, err: PayloadError) { fn set_error(&mut self, err: PayloadError) {
self.inner.set_error(err) self.inner.set_error(err)
} }
@ -110,7 +111,7 @@ impl PayloadWriter for EncodedPayload {
Err(err) => { Err(err) => {
self.error = true; self.error = true;
self.set_error(PayloadError::Io(err)); self.set_error(PayloadError::Io(err));
}, }
Ok(value) => { Ok(value) => {
if let Some(b) = value { if let Some(b) = value {
self.inner.feed_data(b); self.inner.feed_data(b);
@ -123,7 +124,7 @@ impl PayloadWriter for EncodedPayload {
fn feed_data(&mut self, data: Bytes) { fn feed_data(&mut self, data: Bytes) {
if self.error { if self.error {
return return;
} }
match self.payload.feed_data(data) { match self.payload.feed_data(data) {
@ -145,7 +146,7 @@ impl PayloadWriter for EncodedPayload {
pub(crate) enum Decoder { pub(crate) enum Decoder {
Deflate(Box<DeflateDecoder<Writer>>), Deflate(Box<DeflateDecoder<Writer>>),
Gzip(Option<Box<GzDecoder<Wrapper>>>), Gzip(Option<Box<GzDecoder<Wrapper>>>),
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
Br(Box<BrotliDecoder<Writer>>), Br(Box<BrotliDecoder<Writer>>),
Identity, Identity,
} }
@ -190,7 +191,9 @@ pub(crate) struct Writer {
impl Writer { impl Writer {
fn new() -> Writer { fn new() -> Writer {
Writer{buf: BytesMut::with_capacity(8192)} Writer {
buf: BytesMut::with_capacity(8192),
}
} }
fn take(&mut self) -> Bytes { fn take(&mut self) -> Bytes {
self.buf.take().freeze() self.buf.take().freeze()
@ -216,65 +219,64 @@ pub(crate) struct PayloadStream {
impl PayloadStream { impl PayloadStream {
pub fn new(enc: ContentEncoding) -> PayloadStream { pub fn new(enc: ContentEncoding) -> PayloadStream {
let dec = match enc { let dec = match enc {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoding::Br => Decoder::Br( ContentEncoding::Br => {
Box::new(BrotliDecoder::new(Writer::new()))), Decoder::Br(Box::new(BrotliDecoder::new(Writer::new())))
ContentEncoding::Deflate => Decoder::Deflate( }
Box::new(DeflateDecoder::new(Writer::new()))), ContentEncoding::Deflate => {
Decoder::Deflate(Box::new(DeflateDecoder::new(Writer::new())))
}
ContentEncoding::Gzip => Decoder::Gzip(None), ContentEncoding::Gzip => Decoder::Gzip(None),
_ => Decoder::Identity, _ => Decoder::Identity,
}; };
PayloadStream{ decoder: dec, dst: BytesMut::new() } PayloadStream {
decoder: dec,
dst: BytesMut::new(),
}
} }
} }
impl PayloadStream { impl PayloadStream {
pub fn feed_eof(&mut self) -> io::Result<Option<Bytes>> { pub fn feed_eof(&mut self) -> io::Result<Option<Bytes>> {
match self.decoder { match self.decoder {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
Decoder::Br(ref mut decoder) => { Decoder::Br(ref mut decoder) => match decoder.finish() {
match decoder.finish() { Ok(mut writer) => {
Ok(mut writer) => { let b = writer.take();
let b = writer.take(); if !b.is_empty() {
if !b.is_empty() { Ok(Some(b))
Ok(Some(b)) } else {
} else { Ok(None)
Ok(None) }
}
},
Err(e) => Err(e),
} }
Err(e) => Err(e),
}, },
Decoder::Gzip(ref mut decoder) => { Decoder::Gzip(ref mut decoder) => {
if let Some(ref mut decoder) = *decoder { if let Some(ref mut decoder) = *decoder {
decoder.as_mut().get_mut().eof = true; decoder.as_mut().get_mut().eof = true;
self.dst.reserve(8192); self.dst.reserve(8192);
match decoder.read(unsafe{self.dst.bytes_mut()}) { match decoder.read(unsafe { self.dst.bytes_mut() }) {
Ok(n) => { Ok(n) => {
unsafe{self.dst.advance_mut(n)}; unsafe { self.dst.advance_mut(n) };
return Ok(Some(self.dst.take().freeze())) return Ok(Some(self.dst.take().freeze()));
} }
Err(e) => Err(e) => return Err(e),
return Err(e),
} }
} else { } else {
Ok(None) Ok(None)
} }
}, }
Decoder::Deflate(ref mut decoder) => { Decoder::Deflate(ref mut decoder) => match decoder.try_finish() {
match decoder.try_finish() { Ok(_) => {
Ok(_) => { let b = decoder.get_mut().take();
let b = decoder.get_mut().take(); if !b.is_empty() {
if !b.is_empty() { Ok(Some(b))
Ok(Some(b)) } else {
} else { Ok(None)
Ok(None) }
}
},
Err(e) => Err(e),
} }
Err(e) => Err(e),
}, },
Decoder::Identity => Ok(None), Decoder::Identity => Ok(None),
} }
@ -282,66 +284,67 @@ impl PayloadStream {
pub fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> { pub fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> {
match self.decoder { match self.decoder {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
Decoder::Br(ref mut decoder) => { Decoder::Br(ref mut decoder) => match decoder.write_all(&data) {
match decoder.write_all(&data) { Ok(_) => {
Ok(_) => { decoder.flush()?;
decoder.flush()?; let b = decoder.get_mut().take();
let b = decoder.get_mut().take(); if !b.is_empty() {
if !b.is_empty() { Ok(Some(b))
Ok(Some(b)) } else {
} else { Ok(None)
Ok(None) }
}
},
Err(e) => Err(e)
} }
Err(e) => Err(e),
}, },
Decoder::Gzip(ref mut decoder) => { Decoder::Gzip(ref mut decoder) => {
if decoder.is_none() { if decoder.is_none() {
*decoder = Some( *decoder = Some(Box::new(GzDecoder::new(Wrapper {
Box::new(GzDecoder::new( buf: BytesMut::from(data),
Wrapper{buf: BytesMut::from(data), eof: false}))); eof: false,
})));
} else { } else {
let _ = decoder.as_mut().unwrap().write(&data); let _ = decoder.as_mut().unwrap().write(&data);
} }
loop { loop {
self.dst.reserve(8192); self.dst.reserve(8192);
match decoder.as_mut() match decoder
.as_mut().unwrap().read(unsafe{self.dst.bytes_mut()}) .as_mut()
.as_mut()
.unwrap()
.read(unsafe { self.dst.bytes_mut() })
{ {
Ok(n) => { Ok(n) => {
if n != 0 { if n != 0 {
unsafe{self.dst.advance_mut(n)}; unsafe { self.dst.advance_mut(n) };
} }
if n == 0 { if n == 0 {
return Ok(Some(self.dst.take().freeze())); return Ok(Some(self.dst.take().freeze()));
} }
} }
Err(e) => { Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock && !self.dst.is_empty() if e.kind() == io::ErrorKind::WouldBlock
&& !self.dst.is_empty()
{ {
return Ok(Some(self.dst.take().freeze())); return Ok(Some(self.dst.take().freeze()));
} }
return Err(e) return Err(e);
} }
} }
} }
}, }
Decoder::Deflate(ref mut decoder) => { Decoder::Deflate(ref mut decoder) => match decoder.write_all(&data) {
match decoder.write_all(&data) { Ok(_) => {
Ok(_) => { decoder.flush()?;
decoder.flush()?; let b = decoder.get_mut().take();
let b = decoder.get_mut().take(); if !b.is_empty() {
if !b.is_empty() { Ok(Some(b))
Ok(Some(b)) } else {
} else { Ok(None)
Ok(None) }
}
},
Err(e) => Err(e),
} }
Err(e) => Err(e),
}, },
Decoder::Identity => Ok(Some(data)), Decoder::Identity => Ok(Some(data)),
} }
@ -351,33 +354,33 @@ impl PayloadStream {
pub(crate) enum ContentEncoder { pub(crate) enum ContentEncoder {
Deflate(DeflateEncoder<TransferEncoding>), Deflate(DeflateEncoder<TransferEncoding>),
Gzip(GzEncoder<TransferEncoding>), Gzip(GzEncoder<TransferEncoding>),
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
Br(BrotliEncoder<TransferEncoding>), Br(BrotliEncoder<TransferEncoding>),
Identity(TransferEncoding), Identity(TransferEncoding),
} }
impl ContentEncoder { impl ContentEncoder {
pub fn empty(bytes: SharedBytes) -> ContentEncoder { pub fn empty(bytes: SharedBytes) -> ContentEncoder {
ContentEncoder::Identity(TransferEncoding::eof(bytes)) ContentEncoder::Identity(TransferEncoding::eof(bytes))
} }
pub fn for_server(buf: SharedBytes, pub fn for_server(
req: &HttpInnerMessage, buf: SharedBytes, req: &HttpInnerMessage, resp: &mut HttpResponse,
resp: &mut HttpResponse, response_encoding: ContentEncoding,
response_encoding: ContentEncoding) -> ContentEncoder ) -> ContentEncoder {
{
let version = resp.version().unwrap_or_else(|| req.version); let version = resp.version().unwrap_or_else(|| req.version);
let is_head = req.method == Method::HEAD; let is_head = req.method == Method::HEAD;
let mut body = resp.replace_body(Body::Empty); let mut body = resp.replace_body(Body::Empty);
let has_body = match body { let has_body = match body {
Body::Empty => false, Body::Empty => false,
Body::Binary(ref bin) => Body::Binary(ref bin) => {
!(response_encoding == ContentEncoding::Auto && bin.len() < 96), !(response_encoding == ContentEncoding::Auto && bin.len() < 96)
}
_ => true, _ => true,
}; };
// Enable content encoding only if response does not contain Content-Encoding header // Enable content encoding only if response does not contain Content-Encoding
// header
let mut encoding = if has_body { let mut encoding = if has_body {
let encoding = match response_encoding { let encoding = match response_encoding {
ContentEncoding::Auto => { ContentEncoding::Auto => {
@ -396,7 +399,9 @@ impl ContentEncoder {
}; };
if encoding.is_compression() { if encoding.is_compression() {
resp.headers_mut().insert( resp.headers_mut().insert(
CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
} }
encoding encoding
} else { } else {
@ -409,23 +414,27 @@ impl ContentEncoder {
resp.headers_mut().remove(CONTENT_LENGTH); resp.headers_mut().remove(CONTENT_LENGTH);
} }
TransferEncoding::length(0, buf) TransferEncoding::length(0, buf)
}, }
Body::Binary(ref mut bytes) => { Body::Binary(ref mut bytes) => {
if !(encoding == ContentEncoding::Identity if !(encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto) || encoding == ContentEncoding::Auto)
{ {
let tmp = SharedBytes::default(); let tmp = SharedBytes::default();
let transfer = TransferEncoding::eof(tmp.clone()); let transfer = TransferEncoding::eof(tmp.clone());
let mut enc = match encoding { let mut enc = match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(
DeflateEncoder::new(transfer, Compression::fast())), DeflateEncoder::new(transfer, Compression::fast()),
ContentEncoding::Gzip => ContentEncoder::Gzip( ),
GzEncoder::new(transfer, Compression::fast())), ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
#[cfg(feature="brotli")] transfer,
ContentEncoding::Br => ContentEncoder::Br( Compression::fast(),
BrotliEncoder::new(transfer, 3)), )),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
ContentEncoder::Br(BrotliEncoder::new(transfer, 3))
}
ContentEncoding::Identity => ContentEncoder::Identity(transfer), ContentEncoding::Identity => ContentEncoder::Identity(transfer),
ContentEncoding::Auto => unreachable!() ContentEncoding::Auto => unreachable!(),
}; };
// TODO return error! // TODO return error!
let _ = enc.write(bytes.clone()); let _ = enc.write(bytes.clone());
@ -438,7 +447,9 @@ impl ContentEncoder {
let mut b = BytesMut::new(); let mut b = BytesMut::new();
let _ = write!(b, "{}", bytes.len()); let _ = write!(b, "{}", bytes.len());
resp.headers_mut().insert( resp.headers_mut().insert(
CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); CONTENT_LENGTH,
HeaderValue::try_from(b.freeze()).unwrap(),
);
} else { } else {
// resp.headers_mut().remove(CONTENT_LENGTH); // resp.headers_mut().remove(CONTENT_LENGTH);
} }
@ -449,8 +460,8 @@ impl ContentEncoder {
if version == Version::HTTP_2 { if version == Version::HTTP_2 {
error!("Connection upgrade is forbidden for HTTP/2"); error!("Connection upgrade is forbidden for HTTP/2");
} else { } else {
resp.headers_mut().insert( resp.headers_mut()
CONNECTION, HeaderValue::from_static("upgrade")); .insert(CONNECTION, HeaderValue::from_static("upgrade"));
} }
if encoding != ContentEncoding::Identity { if encoding != ContentEncoding::Identity {
encoding = ContentEncoding::Identity; encoding = ContentEncoding::Identity;
@ -470,20 +481,24 @@ impl ContentEncoder {
} }
match encoding { match encoding {
ContentEncoding::Deflate => ContentEncoder::Deflate( ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
DeflateEncoder::new(transfer, Compression::fast())), transfer,
ContentEncoding::Gzip => ContentEncoder::Gzip( Compression::fast(),
GzEncoder::new(transfer, Compression::fast())), )),
#[cfg(feature="brotli")] ContentEncoding::Gzip => {
ContentEncoding::Br => ContentEncoder::Br( ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::fast()))
BrotliEncoder::new(transfer, 3)), }
ContentEncoding::Identity | ContentEncoding::Auto => #[cfg(feature = "brotli")]
ContentEncoder::Identity(transfer), ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
} }
} }
fn streaming_encoding(buf: SharedBytes, version: Version, fn streaming_encoding(
resp: &mut HttpResponse) -> TransferEncoding { buf: SharedBytes, version: Version, resp: &mut HttpResponse
) -> TransferEncoding {
match resp.chunked() { match resp.chunked() {
Some(true) => { Some(true) => {
// Enable transfer encoding // Enable transfer encoding
@ -492,13 +507,12 @@ impl ContentEncoder {
resp.headers_mut().remove(TRANSFER_ENCODING); resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf) TransferEncoding::eof(buf)
} else { } else {
resp.headers_mut().insert( resp.headers_mut()
TRANSFER_ENCODING, HeaderValue::from_static("chunked")); .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
TransferEncoding::chunked(buf) TransferEncoding::chunked(buf)
} }
}, }
Some(false) => Some(false) => TransferEncoding::eof(buf),
TransferEncoding::eof(buf),
None => { None => {
// if Content-Length is specified, then use it as length hint // if Content-Length is specified, then use it as length hint
let (len, chunked) = let (len, chunked) =
@ -530,9 +544,11 @@ impl ContentEncoder {
match version { match version {
Version::HTTP_11 => { Version::HTTP_11 => {
resp.headers_mut().insert( resp.headers_mut().insert(
TRANSFER_ENCODING, HeaderValue::from_static("chunked")); TRANSFER_ENCODING,
HeaderValue::from_static("chunked"),
);
TransferEncoding::chunked(buf) TransferEncoding::chunked(buf)
}, }
_ => { _ => {
resp.headers_mut().remove(TRANSFER_ENCODING); resp.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf) TransferEncoding::eof(buf)
@ -545,11 +561,10 @@ impl ContentEncoder {
} }
impl ContentEncoder { impl ContentEncoder {
#[inline] #[inline]
pub fn is_eof(&self) -> bool { pub fn is_eof(&self) -> bool {
match *self { match *self {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(),
ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(), ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(),
@ -561,39 +576,35 @@ impl ContentEncoder {
#[inline(always)] #[inline(always)]
pub fn write_eof(&mut self) -> Result<(), io::Error> { pub fn write_eof(&mut self) -> Result<(), io::Error> {
let encoder = mem::replace( let encoder = mem::replace(
self, ContentEncoder::Identity(TransferEncoding::eof(SharedBytes::empty()))); self,
ContentEncoder::Identity(TransferEncoding::eof(SharedBytes::empty())),
);
match encoder { match encoder {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoder::Br(encoder) => { ContentEncoder::Br(encoder) => match encoder.finish() {
match encoder.finish() { Ok(mut writer) => {
Ok(mut writer) => { writer.encode_eof();
writer.encode_eof(); *self = ContentEncoder::Identity(writer);
*self = ContentEncoder::Identity(writer); Ok(())
Ok(())
},
Err(err) => Err(err),
}
}
ContentEncoder::Gzip(encoder) => {
match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
},
Err(err) => Err(err),
} }
Err(err) => Err(err),
}, },
ContentEncoder::Deflate(encoder) => { ContentEncoder::Gzip(encoder) => match encoder.finish() {
match encoder.finish() { Ok(mut writer) => {
Ok(mut writer) => { writer.encode_eof();
writer.encode_eof(); *self = ContentEncoder::Identity(writer);
*self = ContentEncoder::Identity(writer); Ok(())
Ok(())
},
Err(err) => Err(err),
} }
Err(err) => Err(err),
},
ContentEncoder::Deflate(encoder) => match encoder.finish() {
Ok(mut writer) => {
writer.encode_eof();
*self = ContentEncoder::Identity(writer);
Ok(())
}
Err(err) => Err(err),
}, },
ContentEncoder::Identity(mut writer) => { ContentEncoder::Identity(mut writer) => {
writer.encode_eof(); writer.encode_eof();
@ -607,23 +618,23 @@ impl ContentEncoder {
#[inline(always)] #[inline(always)]
pub fn write(&mut self, data: Binary) -> Result<(), io::Error> { pub fn write(&mut self, data: Binary) -> Result<(), io::Error> {
match *self { match *self {
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => { ContentEncoder::Br(ref mut encoder) => {
match encoder.write_all(data.as_ref()) { match encoder.write_all(data.as_ref()) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
trace!("Error decoding br encoding: {}", err); trace!("Error decoding br encoding: {}", err);
Err(err) Err(err)
}, }
} }
}, }
ContentEncoder::Gzip(ref mut encoder) => { ContentEncoder::Gzip(ref mut encoder) => {
match encoder.write_all(data.as_ref()) { match encoder.write_all(data.as_ref()) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
trace!("Error decoding gzip encoding: {}", err); trace!("Error decoding gzip encoding: {}", err);
Err(err) Err(err)
}, }
} }
} }
ContentEncoder::Deflate(ref mut encoder) => { ContentEncoder::Deflate(ref mut encoder) => {
@ -632,7 +643,7 @@ impl ContentEncoder {
Err(err) => { Err(err) => {
trace!("Error decoding deflate encoding: {}", err); trace!("Error decoding deflate encoding: {}", err);
Err(err) Err(err)
}, }
} }
} }
ContentEncoder::Identity(ref mut encoder) => { ContentEncoder::Identity(ref mut encoder) => {
@ -665,7 +676,6 @@ enum TransferEncodingKind {
} }
impl TransferEncoding { impl TransferEncoding {
#[inline] #[inline]
pub fn eof(bytes: SharedBytes) -> TransferEncoding { pub fn eof(bytes: SharedBytes) -> TransferEncoding {
TransferEncoding { TransferEncoding {
@ -707,7 +717,7 @@ impl TransferEncoding {
let eof = msg.is_empty(); let eof = msg.is_empty();
self.buffer.extend(msg); self.buffer.extend(msg);
Ok(eof) Ok(eof)
}, }
TransferEncodingKind::Chunked(ref mut eof) => { TransferEncodingKind::Chunked(ref mut eof) => {
if *eof { if *eof {
return Ok(true); return Ok(true);
@ -726,21 +736,22 @@ impl TransferEncoding {
self.buffer.extend_from_slice(b"\r\n"); self.buffer.extend_from_slice(b"\r\n");
} }
Ok(*eof) Ok(*eof)
}, }
TransferEncodingKind::Length(ref mut remaining) => { TransferEncodingKind::Length(ref mut remaining) => {
if *remaining > 0 { if *remaining > 0 {
if msg.is_empty() { if msg.is_empty() {
return Ok(*remaining == 0) return Ok(*remaining == 0);
} }
let len = cmp::min(*remaining, msg.len() as u64); let len = cmp::min(*remaining, msg.len() as u64);
self.buffer.extend(msg.take().split_to(len as usize).into()); self.buffer
.extend(msg.take().split_to(len as usize).into());
*remaining -= len as u64; *remaining -= len as u64;
Ok(*remaining == 0) Ok(*remaining == 0)
} else { } else {
Ok(true) Ok(true)
} }
}, }
} }
} }
@ -754,13 +765,12 @@ impl TransferEncoding {
*eof = true; *eof = true;
self.buffer.extend_from_slice(b"0\r\n\r\n"); self.buffer.extend_from_slice(b"0\r\n\r\n");
} }
}, }
} }
} }
} }
impl io::Write for TransferEncoding { impl io::Write for TransferEncoding {
#[inline] #[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.encode(Binary::from_slice(buf))?; self.encode(Binary::from_slice(buf))?;
@ -773,7 +783,6 @@ impl io::Write for TransferEncoding {
} }
} }
struct AcceptEncoding { struct AcceptEncoding {
encoding: ContentEncoding, encoding: ContentEncoding,
quality: f64, quality: f64,
@ -817,27 +826,31 @@ impl AcceptEncoding {
_ => match f64::from_str(parts[1]) { _ => match f64::from_str(parts[1]) {
Ok(q) => q, Ok(q) => q,
Err(_) => 0.0, Err(_) => 0.0,
} },
}; };
Some(AcceptEncoding{ encoding, quality }) Some(AcceptEncoding {
encoding,
quality,
})
} }
/// Parse a raw Accept-Encoding header value into an ordered list. /// Parse a raw Accept-Encoding header value into an ordered list.
pub fn parse(raw: &str) -> ContentEncoding { pub fn parse(raw: &str) -> ContentEncoding {
let mut encodings: Vec<_> = let mut encodings: Vec<_> = raw.replace(' ', "")
raw.replace(' ', "").split(',').map(|l| AcceptEncoding::new(l)).collect(); .split(',')
.map(|l| AcceptEncoding::new(l))
.collect();
encodings.sort(); encodings.sort();
for enc in encodings { for enc in encodings {
if let Some(enc) = enc { if let Some(enc) = enc {
return enc.encoding return enc.encoding;
} }
} }
ContentEncoding::Identity ContentEncoding::Identity
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -846,9 +859,13 @@ mod tests {
fn test_chunked_te() { fn test_chunked_te() {
let bytes = SharedBytes::default(); let bytes = SharedBytes::default();
let mut enc = TransferEncoding::chunked(bytes.clone()); let mut enc = TransferEncoding::chunked(bytes.clone());
assert!(!enc.encode(Binary::from(b"test".as_ref())).ok().unwrap()); assert!(!enc.encode(Binary::from(b"test".as_ref()))
.ok()
.unwrap());
assert!(enc.encode(Binary::from(b"".as_ref())).ok().unwrap()); assert!(enc.encode(Binary::from(b"".as_ref())).ok().unwrap());
assert_eq!(bytes.get_mut().take().freeze(), assert_eq!(
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")); bytes.get_mut().take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +1,22 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] #![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, mem};
use std::rc::Rc;
use bytes::BufMut; use bytes::BufMut;
use futures::{Async, Poll}; use futures::{Async, Poll};
use tokio_io::AsyncWrite;
use http::{Method, Version};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE};
use http::{Method, Version};
use std::rc::Rc;
use std::{io, mem};
use tokio_io::AsyncWrite;
use body::{Body, Binary}; use super::encoding::ContentEncoder;
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use body::{Binary, Body};
use header::ContentEncoding; use header::ContentEncoding;
use httprequest::HttpInnerMessage; use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use super::helpers;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use super::shared::SharedBytes;
use super::encoding::ContentEncoder;
use super::settings::WorkerSettings;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
@ -41,10 +41,9 @@ pub(crate) struct H1Writer<T: AsyncWrite, H: 'static> {
} }
impl<T: AsyncWrite, H: 'static> H1Writer<T, H> { impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
pub fn new(
pub fn new(stream: T, buf: SharedBytes, settings: Rc<WorkerSettings<H>>) stream: T, buf: SharedBytes, settings: Rc<WorkerSettings<H>>
-> H1Writer<T, H> ) -> H1Writer<T, H> {
{
H1Writer { H1Writer {
flags: Flags::empty(), flags: Flags::empty(),
encoder: ContentEncoder::empty(buf.clone()), encoder: ContentEncoder::empty(buf.clone()),
@ -80,11 +79,11 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
match self.stream.write(&data[written..]) { match self.stream.write(&data[written..]) {
Ok(0) => { Ok(0) => {
self.disconnected(); self.disconnected();
return Err(io::Error::new(io::ErrorKind::WriteZero, "")) return Err(io::Error::new(io::ErrorKind::WriteZero, ""));
}, }
Ok(n) => { Ok(n) => {
written += n; written += n;
}, }
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
return Ok(written) return Ok(written)
} }
@ -96,19 +95,18 @@ impl<T: AsyncWrite, H: 'static> H1Writer<T, H> {
} }
impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> { impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline] #[inline]
fn written(&self) -> u64 { fn written(&self) -> u64 {
self.written self.written
} }
fn start(&mut self, fn start(
req: &mut HttpInnerMessage, &mut self, req: &mut HttpInnerMessage, msg: &mut HttpResponse,
msg: &mut HttpResponse, encoding: ContentEncoding,
encoding: ContentEncoding) -> io::Result<WriterState> ) -> io::Result<WriterState> {
{
// prepare task // prepare task
self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); self.encoder =
ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags.insert(Flags::STARTED | Flags::KEEPALIVE); self.flags.insert(Flags::STARTED | Flags::KEEPALIVE);
} else { } else {
@ -119,15 +117,18 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
let version = msg.version().unwrap_or_else(|| req.version); let version = msg.version().unwrap_or_else(|| req.version);
if msg.upgrade() { if msg.upgrade() {
self.flags.insert(Flags::UPGRADE); self.flags.insert(Flags::UPGRADE);
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
} }
// keep-alive // keep-alive
else if self.flags.contains(Flags::KEEPALIVE) { else if self.flags.contains(Flags::KEEPALIVE) {
if version < Version::HTTP_11 { if version < Version::HTTP_11 {
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("keep-alive")); msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
} }
} else if version >= Version::HTTP_11 { } else if version >= Version::HTTP_11 {
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("close")); msg.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("close"));
} }
let body = msg.replace_body(Body::Empty); let body = msg.replace_body(Body::Empty);
@ -137,12 +138,14 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
let reason = msg.reason().as_bytes(); let reason = msg.reason().as_bytes();
let mut is_bin = if let Body::Binary(ref bytes) = body { let mut is_bin = if let Body::Binary(ref bytes) = body {
buffer.reserve( buffer.reserve(
256 + msg.headers().len() * AVERAGE_HEADER_SIZE 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()
+ bytes.len() + reason.len()); + reason.len(),
);
true true
} else { } else {
buffer.reserve( buffer.reserve(
256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len()); 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len(),
);
false false
}; };
@ -151,51 +154,50 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
SharedBytes::extend_from_slice_(buffer, reason); SharedBytes::extend_from_slice_(buffer, reason);
match body { match body {
Body::Empty => Body::Empty => if req.method != Method::HEAD {
if req.method != Method::HEAD { SharedBytes::put_slice(buffer, b"\r\ncontent-length: 0\r\n");
SharedBytes::put_slice(buffer, b"\r\ncontent-length: 0\r\n"); } else {
} else { SharedBytes::put_slice(buffer, b"\r\n");
SharedBytes::put_slice(buffer, b"\r\n"); },
}, Body::Binary(ref bytes) => {
Body::Binary(ref bytes) => helpers::write_content_length(bytes.len(), &mut buffer)
helpers::write_content_length(bytes.len(), &mut buffer), }
_ => _ => SharedBytes::put_slice(buffer, b"\r\n"),
SharedBytes::put_slice(buffer, b"\r\n"),
} }
// write headers // write headers
let mut pos = 0; let mut pos = 0;
let mut has_date = false; let mut has_date = false;
let mut remaining = buffer.remaining_mut(); let mut remaining = buffer.remaining_mut();
let mut buf: &mut [u8] = unsafe{ mem::transmute(buffer.bytes_mut()) }; let mut buf: &mut [u8] = unsafe { mem::transmute(buffer.bytes_mut()) };
for (key, value) in msg.headers() { for (key, value) in msg.headers() {
if is_bin && key == CONTENT_LENGTH { if is_bin && key == CONTENT_LENGTH {
is_bin = false; is_bin = false;
continue continue;
} }
has_date = has_date || key == DATE; has_date = has_date || key == DATE;
let v = value.as_ref(); let v = value.as_ref();
let k = key.as_str().as_bytes(); let k = key.as_str().as_bytes();
let len = k.len() + v.len() + 4; let len = k.len() + v.len() + 4;
if len > remaining { if len > remaining {
unsafe{buffer.advance_mut(pos)}; unsafe { buffer.advance_mut(pos) };
pos = 0; pos = 0;
buffer.reserve(len); buffer.reserve(len);
remaining = buffer.remaining_mut(); remaining = buffer.remaining_mut();
buf = unsafe{ mem::transmute(buffer.bytes_mut()) }; buf = unsafe { mem::transmute(buffer.bytes_mut()) };
} }
buf[pos..pos+k.len()].copy_from_slice(k); buf[pos..pos + k.len()].copy_from_slice(k);
pos += k.len(); pos += k.len();
buf[pos..pos+2].copy_from_slice(b": "); buf[pos..pos + 2].copy_from_slice(b": ");
pos += 2; pos += 2;
buf[pos..pos+v.len()].copy_from_slice(v); buf[pos..pos + v.len()].copy_from_slice(v);
pos += v.len(); pos += v.len();
buf[pos..pos+2].copy_from_slice(b"\r\n"); buf[pos..pos + 2].copy_from_slice(b"\r\n");
pos += 2; pos += 2;
remaining -= len; remaining -= len;
} }
unsafe{buffer.advance_mut(pos)}; unsafe { buffer.advance_mut(pos) };
// optimized date header, set_date writes \r\n // optimized date header, set_date writes \r\n
if !has_date { if !has_date {
@ -256,8 +258,10 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
self.encoder.write_eof()?; self.encoder.write_eof()?;
if !self.encoder.is_eof() { if !self.encoder.is_eof() {
Err(io::Error::new(io::ErrorKind::Other, Err(io::Error::new(
"Last payload item, but eof is not reached")) io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
} else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
Ok(WriterState::Pause) Ok(WriterState::Pause)
} else { } else {
@ -268,11 +272,11 @@ impl<T: AsyncWrite, H: 'static> Writer for H1Writer<T, H> {
#[inline] #[inline]
fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> {
if !self.buffer.is_empty() { if !self.buffer.is_empty() {
let buf: &[u8] = unsafe{mem::transmute(self.buffer.as_ref())}; let buf: &[u8] = unsafe { mem::transmute(self.buffer.as_ref()) };
let written = self.write_data(buf)?; let written = self.write_data(buf)?;
let _ = self.buffer.split_to(written); let _ = self.buffer.split_to(written);
if self.buffer.len() > self.buffer_capacity { if self.buffer.len() > self.buffer_capacity {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} }
if shutdown { if shutdown {

View File

@ -1,30 +1,30 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] #![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, cmp, mem};
use std::rc::Rc;
use std::io::{Read, Write};
use std::time::Duration;
use std::net::SocketAddr;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::rc::Rc;
use std::time::Duration;
use std::{cmp, io, mem};
use actix::Arbiter; use actix::Arbiter;
use modhttp::request::Parts;
use http2::{Reason, RecvStream};
use http2::server::{self, Connection, Handshake, SendResponse};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures::{Async, Poll, Future, Stream}; use futures::{Async, Future, Poll, Stream};
use tokio_io::{AsyncRead, AsyncWrite}; use http2::server::{self, Connection, Handshake, SendResponse};
use http2::{Reason, RecvStream};
use modhttp::request::Parts;
use tokio_core::reactor::Timeout; use tokio_core::reactor::Timeout;
use tokio_io::{AsyncRead, AsyncWrite};
use pipeline::Pipeline;
use error::PayloadError; use error::PayloadError;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use payload::{Payload, PayloadWriter, PayloadStatus}; use payload::{Payload, PayloadStatus, PayloadWriter};
use pipeline::Pipeline;
use super::h2writer::H2Writer;
use super::encoding::PayloadType; use super::encoding::PayloadType;
use super::h2writer::H2Writer;
use super::settings::WorkerSettings; use super::settings::WorkerSettings;
use super::{HttpHandler, HttpHandlerTask, Writer}; use super::{HttpHandler, HttpHandlerTask, Writer};
@ -35,9 +35,10 @@ bitflags! {
} }
/// HTTP/2 Transport /// HTTP/2 Transport
pub(crate) pub(crate) struct Http2<T, H>
struct Http2<T, H> where
where T: AsyncRead + AsyncWrite + 'static, H: 'static T: AsyncRead + AsyncWrite + 'static,
H: 'static,
{ {
flags: Flags, flags: Flags,
settings: Rc<WorkerSettings<H>>, settings: Rc<WorkerSettings<H>>,
@ -54,20 +55,23 @@ enum State<T: AsyncRead + AsyncWrite> {
} }
impl<T, H> Http2<T, H> impl<T, H> Http2<T, H>
where T: AsyncRead + AsyncWrite + 'static, where
H: HttpHandler + 'static T: AsyncRead + AsyncWrite + 'static,
H: HttpHandler + 'static,
{ {
pub fn new(settings: Rc<WorkerSettings<H>>, pub fn new(
io: T, settings: Rc<WorkerSettings<H>>, io: T, addr: Option<SocketAddr>, buf: Bytes
addr: Option<SocketAddr>, buf: Bytes) -> Self ) -> Self {
{ Http2 {
Http2{ flags: Flags::empty(), flags: Flags::empty(),
tasks: VecDeque::new(), tasks: VecDeque::new(),
state: State::Handshake( state: State::Handshake(server::handshake(IoWrapper {
server::handshake(IoWrapper{unread: Some(buf), inner: io})), unread: Some(buf),
keepalive_timer: None, inner: io,
addr, })),
settings, keepalive_timer: None,
addr,
settings,
} }
} }
@ -89,7 +93,7 @@ impl<T, H> Http2<T, H>
match timeout.poll() { match timeout.poll() {
Ok(Async::Ready(_)) => { Ok(Async::Ready(_)) => {
trace!("Keep-alive timeout, close connection"); trace!("Keep-alive timeout, close connection");
return Ok(Async::Ready(())) return Ok(Async::Ready(()));
} }
Ok(Async::NotReady) => (), Ok(Async::NotReady) => (),
Err(_) => unreachable!(), Err(_) => unreachable!(),
@ -111,29 +115,30 @@ impl<T, H> Http2<T, H>
Ok(Async::Ready(ready)) => { Ok(Async::Ready(ready)) => {
if ready { if ready {
item.flags.insert( item.flags.insert(
EntryFlags::EOF | EntryFlags::FINISHED); EntryFlags::EOF | EntryFlags::FINISHED,
);
} else { } else {
item.flags.insert(EntryFlags::EOF); item.flags.insert(EntryFlags::EOF);
} }
not_ready = false; not_ready = false;
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
if item.payload.need_read() == PayloadStatus::Read if item.payload.need_read() == PayloadStatus::Read
&& !retry && !retry
{ {
continue continue;
} }
}, }
Err(err) => { Err(err) => {
error!("Unhandled error: {}", err); error!("Unhandled error: {}", err);
item.flags.insert( item.flags.insert(
EntryFlags::EOF | EntryFlags::EOF | EntryFlags::ERROR
EntryFlags::ERROR | | EntryFlags::WRITE_DONE,
EntryFlags::WRITE_DONE); );
item.stream.reset(Reason::INTERNAL_ERROR); item.stream.reset(Reason::INTERNAL_ERROR);
} }
} }
break break;
} }
} else if !item.flags.contains(EntryFlags::FINISHED) { } else if !item.flags.contains(EntryFlags::FINISHED) {
match item.task.poll() { match item.task.poll() {
@ -141,11 +146,12 @@ impl<T, H> Http2<T, H>
Ok(Async::Ready(_)) => { Ok(Async::Ready(_)) => {
not_ready = false; not_ready = false;
item.flags.insert(EntryFlags::FINISHED); item.flags.insert(EntryFlags::FINISHED);
}, }
Err(err) => { Err(err) => {
item.flags.insert( item.flags.insert(
EntryFlags::ERROR | EntryFlags::WRITE_DONE | EntryFlags::ERROR | EntryFlags::WRITE_DONE
EntryFlags::FINISHED); | EntryFlags::FINISHED,
);
error!("Unhandled error: {}", err); error!("Unhandled error: {}", err);
} }
} }
@ -167,13 +173,13 @@ impl<T, H> Http2<T, H>
// cleanup finished tasks // cleanup finished tasks
while !self.tasks.is_empty() { while !self.tasks.is_empty() {
if self.tasks[0].flags.contains(EntryFlags::EOF) && if self.tasks[0].flags.contains(EntryFlags::EOF)
self.tasks[0].flags.contains(EntryFlags::WRITE_DONE) || && self.tasks[0].flags.contains(EntryFlags::WRITE_DONE)
self.tasks[0].flags.contains(EntryFlags::ERROR) || self.tasks[0].flags.contains(EntryFlags::ERROR)
{ {
self.tasks.pop_front(); self.tasks.pop_front();
} else { } else {
break break;
} }
} }
@ -186,7 +192,7 @@ impl<T, H> Http2<T, H>
for entry in &mut self.tasks { for entry in &mut self.tasks {
entry.task.disconnected() entry.task.disconnected()
} }
}, }
Ok(Async::Ready(Some((req, resp)))) => { Ok(Async::Ready(Some((req, resp)))) => {
not_ready = false; not_ready = false;
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
@ -194,8 +200,13 @@ impl<T, H> Http2<T, H>
// stop keepalive timer // stop keepalive timer
self.keepalive_timer.take(); self.keepalive_timer.take();
self.tasks.push_back( self.tasks.push_back(Entry::new(
Entry::new(parts, body, resp, self.addr, &self.settings)); parts,
body,
resp,
self.addr,
&self.settings,
));
} }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
// start keep-alive timer // start keep-alive timer
@ -213,12 +224,13 @@ impl<T, H> Http2<T, H>
} }
} else { } else {
// keep-alive disable, drop connection // keep-alive disable, drop connection
return conn.poll_close().map_err( return conn.poll_close().map_err(|e| {
|e| error!("Error during connection close: {}", e)) error!("Error during connection close: {}", e)
});
} }
} else { } else {
// keep-alive unset, rely on operating system // keep-alive unset, rely on operating system
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} }
Err(err) => { Err(err) => {
@ -228,16 +240,17 @@ impl<T, H> Http2<T, H>
entry.task.disconnected() entry.task.disconnected()
} }
self.keepalive_timer.take(); self.keepalive_timer.take();
}, }
} }
} }
if not_ready { if not_ready {
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) { if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED)
return conn.poll_close().map_err( {
|e| error!("Error during connection close: {}", e)) return conn.poll_close()
.map_err(|e| error!("Error during connection close: {}", e));
} else { } else {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} }
} }
@ -246,14 +259,11 @@ impl<T, H> Http2<T, H>
// handshake // handshake
self.state = if let State::Handshake(ref mut handshake) = self.state { self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() { match handshake.poll() {
Ok(Async::Ready(conn)) => { Ok(Async::Ready(conn)) => State::Connection(conn),
State::Connection(conn) Ok(Async::NotReady) => return Ok(Async::NotReady),
},
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Err(err) => { Err(err) => {
trace!("Error handling connection: {}", err); trace!("Error handling connection: {}", err);
return Err(()) return Err(());
} }
} }
} else { } else {
@ -283,12 +293,12 @@ struct Entry<H: 'static> {
} }
impl<H: 'static> Entry<H> { impl<H: 'static> Entry<H> {
fn new(parts: Parts, fn new(
recv: RecvStream, parts: Parts, recv: RecvStream, resp: SendResponse<Bytes>,
resp: SendResponse<Bytes>, addr: Option<SocketAddr>, settings: &Rc<WorkerSettings<H>>,
addr: Option<SocketAddr>, ) -> Entry<H>
settings: &Rc<WorkerSettings<H>>) -> Entry<H> where
where H: HttpHandler + 'static H: HttpHandler + 'static,
{ {
// Payload and Content-Encoding // Payload and Content-Encoding
let (psender, payload) = Payload::new(false); let (psender, payload) = Payload::new(false);
@ -312,18 +322,22 @@ impl<H: 'static> Entry<H> {
req = match h.handle(req) { req = match h.handle(req) {
Ok(t) => { Ok(t) => {
task = Some(t); task = Some(t);
break break;
}, }
Err(req) => req, Err(req) => req,
} }
} }
Entry {task: task.unwrap_or_else(|| Pipeline::error(HttpResponse::NotFound())), Entry {
payload: psender, task: task.unwrap_or_else(|| Pipeline::error(HttpResponse::NotFound())),
stream: H2Writer::new( payload: psender,
resp, settings.get_shared_bytes(), Rc::clone(settings)), stream: H2Writer::new(
flags: EntryFlags::empty(), resp,
recv, settings.get_shared_bytes(),
Rc::clone(settings),
),
flags: EntryFlags::empty(),
recv,
} }
} }
@ -340,14 +354,12 @@ impl<H: 'static> Entry<H> {
match self.recv.poll() { match self.recv.poll() {
Ok(Async::Ready(Some(chunk))) => { Ok(Async::Ready(Some(chunk))) => {
self.payload.feed_data(chunk); self.payload.feed_data(chunk);
}, }
Ok(Async::Ready(None)) => { Ok(Async::Ready(None)) => {
self.flags.insert(EntryFlags::REOF); self.flags.insert(EntryFlags::REOF);
},
Ok(Async::NotReady) => (),
Err(err) => {
self.payload.set_error(PayloadError::Http2(err))
} }
Ok(Async::NotReady) => (),
Err(err) => self.payload.set_error(PayloadError::Http2(err)),
} }
} }
} }

View File

@ -1,25 +1,25 @@
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] #![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
use std::{io, cmp};
use std::rc::Rc;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll}; use futures::{Async, Poll};
use http2::{Reason, SendStream};
use http2::server::SendResponse; use http2::server::SendResponse;
use http2::{Reason, SendStream};
use modhttp::Response; use modhttp::Response;
use std::rc::Rc;
use std::{cmp, io};
use http::{Version, HttpTryFrom}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::header::{HeaderValue, CONNECTION, TRANSFER_ENCODING, DATE, CONTENT_LENGTH}; use http::{HttpTryFrom, Version};
use body::{Body, Binary}; use super::encoding::ContentEncoder;
use super::helpers;
use super::settings::WorkerSettings;
use super::shared::SharedBytes;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
use body::{Binary, Body};
use header::ContentEncoding; use header::ContentEncoding;
use httprequest::HttpInnerMessage; use httprequest::HttpInnerMessage;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use super::helpers;
use super::encoding::ContentEncoder;
use super::shared::SharedBytes;
use super::settings::WorkerSettings;
use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE};
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
@ -44,10 +44,9 @@ pub(crate) struct H2Writer<H: 'static> {
} }
impl<H: 'static> H2Writer<H> { impl<H: 'static> H2Writer<H> {
pub fn new(
pub fn new(respond: SendResponse<Bytes>, respond: SendResponse<Bytes>, buf: SharedBytes, settings: Rc<WorkerSettings<H>>
buf: SharedBytes, settings: Rc<WorkerSettings<H>>) -> H2Writer<H> ) -> H2Writer<H> {
{
H2Writer { H2Writer {
respond, respond,
settings, settings,
@ -68,19 +67,18 @@ impl<H: 'static> H2Writer<H> {
} }
impl<H: 'static> Writer for H2Writer<H> { impl<H: 'static> Writer for H2Writer<H> {
fn written(&self) -> u64 { fn written(&self) -> u64 {
self.written self.written
} }
fn start(&mut self, fn start(
req: &mut HttpInnerMessage, &mut self, req: &mut HttpInnerMessage, msg: &mut HttpResponse,
msg: &mut HttpResponse, encoding: ContentEncoding,
encoding: ContentEncoding) -> io::Result<WriterState> ) -> io::Result<WriterState> {
{
// prepare response // prepare response
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); self.encoder =
ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding);
if let Body::Empty = *msg.body() { if let Body::Empty = *msg.body() {
self.flags.insert(Flags::EOF); self.flags.insert(Flags::EOF);
} }
@ -93,7 +91,8 @@ impl<H: 'static> Writer for H2Writer<H> {
if !msg.headers().contains_key(DATE) { if !msg.headers().contains_key(DATE) {
let mut bytes = BytesMut::with_capacity(29); let mut bytes = BytesMut::with_capacity(29);
self.settings.set_date_simple(&mut bytes); self.settings.set_date_simple(&mut bytes);
msg.headers_mut().insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); msg.headers_mut()
.insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap());
} }
let body = msg.replace_body(Body::Empty); let body = msg.replace_body(Body::Empty);
@ -104,11 +103,13 @@ impl<H: 'static> Writer for H2Writer<H> {
let l = val.len(); let l = val.len();
msg.headers_mut().insert( msg.headers_mut().insert(
CONTENT_LENGTH, CONTENT_LENGTH,
HeaderValue::try_from(val.split_to(l-2).freeze()).unwrap()); HeaderValue::try_from(val.split_to(l - 2).freeze()).unwrap(),
);
} }
Body::Empty => { Body::Empty => {
msg.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("0")); msg.headers_mut()
}, .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
}
_ => (), _ => (),
} }
@ -119,11 +120,11 @@ impl<H: 'static> Writer for H2Writer<H> {
resp.headers_mut().insert(key, value.clone()); resp.headers_mut().insert(key, value.clone());
} }
match self.respond.send_response(resp, self.flags.contains(Flags::EOF)) { match self.respond
Ok(stream) => .send_response(resp, self.flags.contains(Flags::EOF))
self.stream = Some(stream), {
Err(_) => Ok(stream) => self.stream = Some(stream),
return Err(io::Error::new(io::ErrorKind::Other, "err")), Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "err")),
} }
trace!("Response: {:?}", msg); trace!("Response: {:?}", msg);
@ -169,8 +170,10 @@ impl<H: 'static> Writer for H2Writer<H> {
self.flags.insert(Flags::EOF); self.flags.insert(Flags::EOF);
if !self.encoder.is_eof() { if !self.encoder.is_eof() {
Err(io::Error::new(io::ErrorKind::Other, Err(io::Error::new(
"Last payload item, but eof is not reached")) io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
} else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
Ok(WriterState::Pause) Ok(WriterState::Pause)
} else { } else {
@ -197,17 +200,18 @@ impl<H: 'static> Writer for H2Writer<H> {
Ok(Async::Ready(Some(cap))) => { Ok(Async::Ready(Some(cap))) => {
let len = self.buffer.len(); let len = self.buffer.len();
let bytes = self.buffer.split_to(cmp::min(cap, len)); let bytes = self.buffer.split_to(cmp::min(cap, len));
let eof = self.buffer.is_empty() && self.flags.contains(Flags::EOF); let eof =
self.buffer.is_empty() && self.flags.contains(Flags::EOF);
self.written += bytes.len() as u64; self.written += bytes.len() as u64;
if let Err(e) = stream.send_data(bytes.freeze(), eof) { if let Err(e) = stream.send_data(bytes.freeze(), eof) {
return Err(io::Error::new(io::ErrorKind::Other, e)) return Err(io::Error::new(io::ErrorKind::Other, e));
} else if !self.buffer.is_empty() { } else if !self.buffer.is_empty() {
let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); let cap = cmp::min(self.buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap); stream.reserve_capacity(cap);
} else { } else {
self.flags.remove(Flags::RESERVED); self.flags.remove(Flags::RESERVED);
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
} }
Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),

View File

@ -1,9 +1,9 @@
use std::{mem, ptr, slice};
use std::cell::RefCell;
use std::rc::Rc;
use std::collections::VecDeque;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use http::Version; use http::Version;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::{mem, ptr, slice};
use httprequest::HttpInnerMessage; use httprequest::HttpInnerMessage;
@ -35,7 +35,9 @@ impl SharedMessagePool {
} }
pub(crate) struct SharedHttpInnerMessage( pub(crate) struct SharedHttpInnerMessage(
Option<Rc<HttpInnerMessage>>, Option<Rc<SharedMessagePool>>); Option<Rc<HttpInnerMessage>>,
Option<Rc<SharedMessagePool>>,
);
impl Drop for SharedHttpInnerMessage { impl Drop for SharedHttpInnerMessage {
fn drop(&mut self) { fn drop(&mut self) {
@ -50,26 +52,25 @@ impl Drop for SharedHttpInnerMessage {
} }
impl Clone for SharedHttpInnerMessage { impl Clone for SharedHttpInnerMessage {
fn clone(&self) -> SharedHttpInnerMessage { fn clone(&self) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(self.0.clone(), self.1.clone()) SharedHttpInnerMessage(self.0.clone(), self.1.clone())
} }
} }
impl Default for SharedHttpInnerMessage { impl Default for SharedHttpInnerMessage {
fn default() -> SharedHttpInnerMessage { fn default() -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(Rc::new(HttpInnerMessage::default())), None) SharedHttpInnerMessage(Some(Rc::new(HttpInnerMessage::default())), None)
} }
} }
impl SharedHttpInnerMessage { impl SharedHttpInnerMessage {
pub fn from_message(msg: HttpInnerMessage) -> SharedHttpInnerMessage { pub fn from_message(msg: HttpInnerMessage) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(Rc::new(msg)), None) SharedHttpInnerMessage(Some(Rc::new(msg)), None)
} }
pub fn new(msg: Rc<HttpInnerMessage>, pool: Rc<SharedMessagePool>) -> SharedHttpInnerMessage { pub fn new(
msg: Rc<HttpInnerMessage>, pool: Rc<SharedMessagePool>
) -> SharedHttpInnerMessage {
SharedHttpInnerMessage(Some(msg), Some(pool)) SharedHttpInnerMessage(Some(msg), Some(pool))
} }
@ -78,7 +79,7 @@ impl SharedHttpInnerMessage {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))]
pub fn get_mut(&self) -> &mut HttpInnerMessage { pub fn get_mut(&self) -> &mut HttpInnerMessage {
let r: &HttpInnerMessage = self.0.as_ref().unwrap().as_ref(); let r: &HttpInnerMessage = self.0.as_ref().unwrap().as_ref();
unsafe{mem::transmute(r)} unsafe { mem::transmute(r) }
} }
#[inline(always)] #[inline(always)]
@ -88,20 +89,23 @@ impl SharedHttpInnerMessage {
} }
} }
const DEC_DIGITS_LUT: &[u8] = const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\
b"0001020304050607080910111213141516171819\
2021222324252627282930313233343536373839\ 2021222324252627282930313233343536373839\
4041424344454647484950515253545556575859\ 4041424344454647484950515253545556575859\
6061626364656667686970717273747576777879\ 6061626364656667686970717273747576777879\
8081828384858687888990919293949596979899"; 8081828384858687888990919293949596979899";
pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) { pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) {
let mut buf: [u8; 13] = [b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', let mut buf: [u8; 13] = [
b' ', b' ', b' ', b' ', b' ']; b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', b' ', b' ', b' ', b' ', b' '
];
match version { match version {
Version::HTTP_2 => buf[5] = b'2', Version::HTTP_2 => buf[5] = b'2',
Version::HTTP_10 => buf[7] = b'0', Version::HTTP_10 => buf[7] = b'0',
Version::HTTP_09 => {buf[5] = b'0'; buf[7] = b'9';}, Version::HTTP_09 => {
buf[5] = b'0';
buf[7] = b'9';
}
_ => (), _ => (),
} }
@ -124,7 +128,11 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM
} else { } else {
let d1 = n << 1; let d1 = n << 1;
curr -= 2; curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); ptr::copy_nonoverlapping(
lut_ptr.offset(d1 as isize),
buf_ptr.offset(curr),
2,
);
} }
} }
@ -137,30 +145,41 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM
/// NOTE: bytes object has to contain enough space /// NOTE: bytes object has to contain enough space
pub(crate) fn write_content_length(mut n: usize, bytes: &mut BytesMut) { pub(crate) fn write_content_length(mut n: usize, bytes: &mut BytesMut) {
if n < 10 { if n < 10 {
let mut buf: [u8; 21] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', let mut buf: [u8; 21] = [
b'n',b't',b'-',b'l',b'e',b'n',b'g', b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b't',b'h',b':',b' ',b'0',b'\r',b'\n']; b'n', b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n',
];
buf[18] = (n as u8) + b'0'; buf[18] = (n as u8) + b'0';
bytes.put_slice(&buf); bytes.put_slice(&buf);
} else if n < 100 { } else if n < 100 {
let mut buf: [u8; 22] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', let mut buf: [u8; 22] = [
b'n',b't',b'-',b'l',b'e',b'n',b'g', b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b't',b'h',b':',b' ',b'0',b'0',b'\r',b'\n']; b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n',
];
let d1 = n << 1; let d1 = n << 1;
unsafe { unsafe {
ptr::copy_nonoverlapping( ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(18), 2); DEC_DIGITS_LUT.as_ptr().offset(d1 as isize),
buf.as_mut_ptr().offset(18),
2,
);
} }
bytes.put_slice(&buf); bytes.put_slice(&buf);
} else if n < 1000 { } else if n < 1000 {
let mut buf: [u8; 23] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', let mut buf: [u8; 23] = [
b'n',b't',b'-',b'l',b'e',b'n',b'g', b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e',
b't',b'h',b':',b' ',b'0',b'0',b'0',b'\r',b'\n']; b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', b'\n',
];
// decode 2 more chars, if > 2 chars // decode 2 more chars, if > 2 chars
let d1 = (n % 100) << 1; let d1 = (n % 100) << 1;
n /= 100; n /= 100;
unsafe {ptr::copy_nonoverlapping( unsafe {
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(19), 2)}; ptr::copy_nonoverlapping(
DEC_DIGITS_LUT.as_ptr().offset(d1 as isize),
buf.as_mut_ptr().offset(19),
2,
)
};
// decode last 1 // decode last 1
buf[18] = (n as u8) + b'0'; buf[18] = (n as u8) + b'0';
@ -216,12 +235,13 @@ pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) {
} }
unsafe { unsafe {
bytes.extend_from_slice( bytes.extend_from_slice(slice::from_raw_parts(
slice::from_raw_parts(buf_ptr.offset(curr), 41 - curr as usize)); buf_ptr.offset(curr),
41 - curr as usize,
));
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -231,33 +251,63 @@ mod tests {
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
bytes.reserve(50); bytes.reserve(50);
write_content_length(0, &mut bytes); write_content_length(0, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 0\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(9, &mut bytes); write_content_length(9, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 9\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(10, &mut bytes); write_content_length(10, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 10\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(99, &mut bytes); write_content_length(99, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 99\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(100, &mut bytes); write_content_length(100, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 100\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(101, &mut bytes); write_content_length(101, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 101\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(998, &mut bytes); write_content_length(998, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 998\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(1000, &mut bytes); write_content_length(1000, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 1000\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(1001, &mut bytes); write_content_length(1001, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 1001\r\n"[..]
);
bytes.reserve(50); bytes.reserve(50);
write_content_length(5909, &mut bytes); write_content_length(5909, &mut bytes);
assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); assert_eq!(
bytes.take().freeze(),
b"\r\ncontent-length: 5909\r\n"[..]
);
} }
} }

View File

@ -1,27 +1,27 @@
//! Http server //! Http server
use std::{time, io};
use std::net::Shutdown; use std::net::Shutdown;
use std::{io, time};
use actix; use actix;
use futures::Poll; use futures::Poll;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_core::net::TcpStream; use tokio_core::net::TcpStream;
use tokio_io::{AsyncRead, AsyncWrite};
mod srv;
mod worker;
mod channel; mod channel;
pub(crate) mod encoding; pub(crate) mod encoding;
pub(crate) mod h1; pub(crate) mod h1;
mod h2;
mod h1writer; mod h1writer;
mod h2;
mod h2writer; mod h2writer;
mod settings;
pub(crate) mod helpers; pub(crate) mod helpers;
mod settings;
pub(crate) mod shared; pub(crate) mod shared;
mod srv;
pub(crate) mod utils; pub(crate) mod utils;
mod worker;
pub use self::srv::HttpServer;
pub use self::settings::ServerSettings; pub use self::settings::ServerSettings;
pub use self::srv::HttpServer;
use body::Binary; use body::Binary;
use error::Error; use error::Error;
@ -56,9 +56,10 @@ pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// } /// }
/// ``` /// ```
pub fn new<F, U, H>(factory: F) -> HttpServer<H> pub fn new<F, U, H>(factory: F) -> HttpServer<H>
where F: Fn() -> U + Sync + Send + 'static, where
U: IntoIterator<Item=H> + 'static, F: Fn() -> U + Sync + Send + 'static,
H: IntoHttpHandler + 'static U: IntoIterator<Item = H> + 'static,
H: IntoHttpHandler + 'static,
{ {
HttpServer::new(factory) HttpServer::new(factory)
} }
@ -107,7 +108,7 @@ pub struct ResumeServer;
/// ///
/// If server starts with `spawn()` method, then spawned thread get terminated. /// If server starts with `spawn()` method, then spawned thread get terminated.
pub struct StopServer { pub struct StopServer {
pub graceful: bool pub graceful: bool,
} }
impl actix::Message for StopServer { impl actix::Message for StopServer {
@ -117,7 +118,6 @@ impl actix::Message for StopServer {
/// Low level http request handler /// Low level http request handler
#[allow(unused_variables)] #[allow(unused_variables)]
pub trait HttpHandler: 'static { pub trait HttpHandler: 'static {
/// Handle request /// Handle request
fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest>; fn handle(&mut self, req: HttpRequest) -> Result<Box<HttpHandlerTask>, HttpRequest>;
} }
@ -130,7 +130,6 @@ impl HttpHandler for Box<HttpHandler> {
#[doc(hidden)] #[doc(hidden)]
pub trait HttpHandlerTask { pub trait HttpHandlerTask {
/// Poll task, this method is used before or after *io* object is available /// Poll task, this method is used before or after *io* object is available
fn poll(&mut self) -> Poll<(), Error>; fn poll(&mut self) -> Poll<(), Error>;
@ -170,8 +169,10 @@ pub enum WriterState {
pub trait Writer { pub trait Writer {
fn written(&self) -> u64; fn written(&self) -> u64;
fn start(&mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse, encoding: ContentEncoding) fn start(
-> io::Result<WriterState>; &mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse,
encoding: ContentEncoding,
) -> io::Result<WriterState>;
fn write(&mut self, payload: Binary) -> io::Result<WriterState>; fn write(&mut self, payload: Binary) -> io::Result<WriterState>;
@ -207,10 +208,10 @@ impl IoStream for TcpStream {
} }
} }
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use tokio_openssl::SslStream; use tokio_openssl::SslStream;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
impl IoStream for SslStream<TcpStream> { impl IoStream for SslStream<TcpStream> {
#[inline] #[inline]
fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {
@ -229,10 +230,10 @@ impl IoStream for SslStream<TcpStream> {
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
use tokio_tls::TlsStream; use tokio_tls::TlsStream;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
impl IoStream for TlsStream<TcpStream> { impl IoStream for TlsStream<TcpStream> {
#[inline] #[inline]
fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {

View File

@ -1,19 +1,19 @@
use std::{fmt, mem, net}; use bytes::BytesMut;
use futures_cpupool::{Builder, CpuPool};
use http::StatusCode;
use std::cell::{Cell, RefCell, RefMut, UnsafeCell};
use std::fmt::Write; use std::fmt::Write;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::cell::{Cell, RefCell, RefMut, UnsafeCell}; use std::{fmt, mem, net};
use time; use time;
use bytes::BytesMut;
use http::StatusCode;
use futures_cpupool::{Builder, CpuPool};
use super::helpers;
use super::KeepAlive; use super::KeepAlive;
use super::channel::Node; use super::channel::Node;
use super::helpers;
use super::shared::{SharedBytes, SharedBytesPool}; use super::shared::{SharedBytes, SharedBytesPool};
use body::Body; use body::Body;
use httpresponse::{HttpResponse, HttpResponsePool, HttpResponseBuilder}; use httpresponse::{HttpResponse, HttpResponseBuilder, HttpResponsePool};
/// Various server settings /// Various server settings
#[derive(Clone)] #[derive(Clone)]
@ -71,9 +71,9 @@ impl Default for ServerSettings {
impl ServerSettings { impl ServerSettings {
/// Crate server settings instance /// Crate server settings instance
pub(crate) fn new(addr: Option<net::SocketAddr>, host: &Option<String>, secure: bool) pub(crate) fn new(
-> ServerSettings addr: Option<net::SocketAddr>, host: &Option<String>, secure: bool
{ ) -> ServerSettings {
let host = if let Some(ref host) = *host { let host = if let Some(ref host) = *host {
host.clone() host.clone()
} else if let Some(ref addr) = addr { } else if let Some(ref addr) = addr {
@ -83,7 +83,13 @@ impl ServerSettings {
}; };
let cpu_pool = Arc::new(InnerCpuPool::new()); let cpu_pool = Arc::new(InnerCpuPool::new());
let responses = HttpResponsePool::pool(); let responses = HttpResponsePool::pool();
ServerSettings { addr, secure, host, cpu_pool, responses } ServerSettings {
addr,
secure,
host,
cpu_pool,
responses,
}
} }
/// Returns the socket address of the local half of this TCP connection /// Returns the socket address of the local half of this TCP connection
@ -112,12 +118,13 @@ impl ServerSettings {
} }
#[inline] #[inline]
pub(crate) fn get_response_builder(&self, status: StatusCode) -> HttpResponseBuilder { pub(crate) fn get_response_builder(
&self, status: StatusCode
) -> HttpResponseBuilder {
HttpResponsePool::get_builder(&self.responses, status) HttpResponsePool::get_builder(&self.responses, status)
} }
} }
// "Sun, 06 Nov 1994 08:49:37 GMT".len() // "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29; const DATE_VALUE_LENGTH: usize = 29;
@ -141,7 +148,8 @@ impl<H> WorkerSettings<H> {
}; };
WorkerSettings { WorkerSettings {
keep_alive, ka_enabled, keep_alive,
ka_enabled,
h: RefCell::new(h), h: RefCell::new(h),
bytes: Rc::new(SharedBytesPool::new()), bytes: Rc::new(SharedBytesPool::new()),
messages: Rc::new(helpers::SharedMessagePool::new()), messages: Rc::new(helpers::SharedMessagePool::new()),
@ -176,7 +184,10 @@ impl<H> WorkerSettings<H> {
} }
pub fn get_http_message(&self) -> helpers::SharedHttpInnerMessage { pub fn get_http_message(&self) -> helpers::SharedHttpInnerMessage {
helpers::SharedHttpInnerMessage::new(self.messages.get(), Rc::clone(&self.messages)) helpers::SharedHttpInnerMessage::new(
self.messages.get(),
Rc::clone(&self.messages),
)
} }
pub fn add_channel(&self) { pub fn add_channel(&self) {
@ -186,26 +197,26 @@ impl<H> WorkerSettings<H> {
pub fn remove_channel(&self) { pub fn remove_channel(&self) {
let num = self.channels.get(); let num = self.channels.get();
if num > 0 { if num > 0 {
self.channels.set(num-1); self.channels.set(num - 1);
} else { } else {
error!("Number of removed channels is bigger than added channel. Bug in actix-web"); error!("Number of removed channels is bigger than added channel. Bug in actix-web");
} }
} }
pub fn update_date(&self) { pub fn update_date(&self) {
unsafe{&mut *self.date.get()}.update(); unsafe { &mut *self.date.get() }.update();
} }
pub fn set_date(&self, dst: &mut BytesMut) { pub fn set_date(&self, dst: &mut BytesMut) {
let mut buf: [u8; 39] = unsafe { mem::uninitialized() }; let mut buf: [u8; 39] = unsafe { mem::uninitialized() };
buf[..6].copy_from_slice(b"date: "); buf[..6].copy_from_slice(b"date: ");
buf[6..35].copy_from_slice(&(unsafe{&*self.date.get()}.bytes)); buf[6..35].copy_from_slice(&(unsafe { &*self.date.get() }.bytes));
buf[35..].copy_from_slice(b"\r\n\r\n"); buf[35..].copy_from_slice(b"\r\n\r\n");
dst.extend_from_slice(&buf); dst.extend_from_slice(&buf);
} }
pub fn set_date_simple(&self, dst: &mut BytesMut) { pub fn set_date_simple(&self, dst: &mut BytesMut) {
dst.extend_from_slice(&(unsafe{&*self.date.get()}.bytes)); dst.extend_from_slice(&(unsafe { &*self.date.get() }.bytes));
} }
} }
@ -216,7 +227,10 @@ struct Date {
impl Date { impl Date {
fn new() -> Date { fn new() -> Date {
let mut date = Date{bytes: [0; DATE_VALUE_LENGTH], pos: 0}; let mut date = Date {
bytes: [0; DATE_VALUE_LENGTH],
pos: 0,
};
date.update(); date.update();
date date
} }
@ -235,14 +249,16 @@ impl fmt::Write for Date {
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_date_len() { fn test_date_len() {
assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); assert_eq!(
DATE_VALUE_LENGTH,
"Sun, 06 Nov 1994 08:49:37 GMT".len()
);
} }
#[test] #[test]

View File

@ -1,12 +1,11 @@
use std::{io, mem};
use std::cell::RefCell;
use std::rc::Rc;
use std::collections::VecDeque;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;
use std::{io, mem};
use body::Binary; use body::Binary;
/// Internal use only! unsafe /// Internal use only! unsafe
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct SharedBytesPool(RefCell<VecDeque<Rc<BytesMut>>>); pub(crate) struct SharedBytesPool(RefCell<VecDeque<Rc<BytesMut>>>);
@ -34,8 +33,7 @@ impl SharedBytesPool {
} }
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct SharedBytes( pub(crate) struct SharedBytes(Option<Rc<BytesMut>>, Option<Rc<SharedBytesPool>>);
Option<Rc<BytesMut>>, Option<Rc<SharedBytesPool>>);
impl Drop for SharedBytes { impl Drop for SharedBytes {
fn drop(&mut self) { fn drop(&mut self) {
@ -50,7 +48,6 @@ impl Drop for SharedBytes {
} }
impl SharedBytes { impl SharedBytes {
pub fn empty() -> Self { pub fn empty() -> Self {
SharedBytes(None, None) SharedBytes(None, None)
} }
@ -64,7 +61,7 @@ impl SharedBytes {
#[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))]
pub(crate) fn get_mut(&self) -> &mut BytesMut { pub(crate) fn get_mut(&self) -> &mut BytesMut {
let r: &BytesMut = self.0.as_ref().unwrap().as_ref(); let r: &BytesMut = self.0.as_ref().unwrap().as_ref();
unsafe{mem::transmute(r)} unsafe { mem::transmute(r) }
} }
#[inline] #[inline]

View File

@ -1,31 +1,33 @@
use std::{io, net, thread};
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, mpsc as sync_mpsc}; use std::sync::{mpsc as sync_mpsc, Arc};
use std::time::Duration; use std::time::Duration;
use std::{io, net, thread};
use actix::prelude::*;
use actix::actors::signal; use actix::actors::signal;
use futures::{Future, Sink, Stream}; use actix::prelude::*;
use futures::sync::mpsc; use futures::sync::mpsc;
use tokio_io::{AsyncRead, AsyncWrite}; use futures::{Future, Sink, Stream};
use mio; use mio;
use num_cpus;
use net2::TcpBuilder; use net2::TcpBuilder;
use num_cpus;
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature="tls")] #[cfg(feature = "tls")]
use native_tls::TlsAcceptor; use native_tls::TlsAcceptor;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use openssl::ssl::{AlpnError, SslAcceptorBuilder}; use openssl::ssl::{AlpnError, SslAcceptorBuilder};
use super::channel::{HttpChannel, WrapperStream};
use super::settings::{ServerSettings, WorkerSettings};
use super::worker::{Conn, StopWorker, StreamHandlerType, Worker};
use super::{IntoHttpHandler, IoStream, KeepAlive}; use super::{IntoHttpHandler, IoStream, KeepAlive};
use super::{PauseServer, ResumeServer, StopServer}; use super::{PauseServer, ResumeServer, StopServer};
use super::channel::{HttpChannel, WrapperStream};
use super::worker::{Conn, Worker, StreamHandlerType, StopWorker};
use super::settings::{ServerSettings, WorkerSettings};
/// An HTTP Server /// An HTTP Server
pub struct HttpServer<H> where H: IntoHttpHandler + 'static pub struct HttpServer<H>
where
H: IntoHttpHandler + 'static,
{ {
h: Option<Rc<WorkerSettings<H::Handler>>>, h: Option<Rc<WorkerSettings<H::Handler>>>,
threads: usize, threads: usize,
@ -33,7 +35,7 @@ pub struct HttpServer<H> where H: IntoHttpHandler + 'static
host: Option<String>, host: Option<String>,
keep_alive: KeepAlive, keep_alive: KeepAlive,
factory: Arc<Fn() -> Vec<H> + Send + Sync>, factory: Arc<Fn() -> Vec<H> + Send + Sync>,
#[cfg_attr(feature="cargo-clippy", allow(type_complexity))] #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))]
workers: Vec<(usize, Addr<Syn, Worker<H::Handler>>)>, workers: Vec<(usize, Addr<Syn, Worker<H::Handler>>)>,
sockets: Vec<(net::SocketAddr, net::TcpListener)>, sockets: Vec<(net::SocketAddr, net::TcpListener)>,
accept: Vec<(mio::SetReadiness, sync_mpsc::Sender<Command>)>, accept: Vec<(mio::SetReadiness, sync_mpsc::Sender<Command>)>,
@ -44,8 +46,16 @@ pub struct HttpServer<H> where H: IntoHttpHandler + 'static
no_signals: bool, no_signals: bool,
} }
unsafe impl<H> Sync for HttpServer<H> where H: IntoHttpHandler {} unsafe impl<H> Sync for HttpServer<H>
unsafe impl<H> Send for HttpServer<H> where H: IntoHttpHandler {} where
H: IntoHttpHandler,
{
}
unsafe impl<H> Send for HttpServer<H>
where
H: IntoHttpHandler,
{
}
#[derive(Clone)] #[derive(Clone)]
struct Info { struct Info {
@ -57,41 +67,47 @@ enum ServerCommand {
WorkerDied(usize, Info), WorkerDied(usize, Info),
} }
impl<H> Actor for HttpServer<H> where H: IntoHttpHandler { impl<H> Actor for HttpServer<H>
where
H: IntoHttpHandler,
{
type Context = Context<Self>; type Context = Context<Self>;
} }
impl<H> HttpServer<H> where H: IntoHttpHandler + 'static impl<H> HttpServer<H>
where
H: IntoHttpHandler + 'static,
{ {
/// Create new http server with application factory /// Create new http server with application factory
pub fn new<F, U>(factory: F) -> Self pub fn new<F, U>(factory: F) -> Self
where F: Fn() -> U + Sync + Send + 'static, where
U: IntoIterator<Item=H> + 'static, F: Fn() -> U + Sync + Send + 'static,
U: IntoIterator<Item = H> + 'static,
{ {
let f = move || { let f = move || (factory)().into_iter().collect();
(factory)().into_iter().collect()
};
HttpServer{ h: None, HttpServer {
threads: num_cpus::get(), h: None,
backlog: 2048, threads: num_cpus::get(),
host: None, backlog: 2048,
keep_alive: KeepAlive::Os, host: None,
factory: Arc::new(f), keep_alive: KeepAlive::Os,
workers: Vec::new(), factory: Arc::new(f),
sockets: Vec::new(), workers: Vec::new(),
accept: Vec::new(), sockets: Vec::new(),
exit: false, accept: Vec::new(),
shutdown_timeout: 30, exit: false,
signals: None, shutdown_timeout: 30,
no_http2: false, signals: None,
no_signals: false, no_http2: false,
no_signals: false,
} }
} }
/// Set number of workers to start. /// Set number of workers to start.
/// ///
/// By default http server uses number of available logical cpu as threads count. /// By default http server uses number of available logical cpu as threads
/// count.
pub fn threads(mut self, num: usize) -> Self { pub fn threads(mut self, num: usize) -> Self {
self.threads = num; self.threads = num;
self self
@ -101,7 +117,8 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
/// ///
/// This refers to the number of clients that can be waiting to be served. /// This refers to the number of clients that can be waiting to be served.
/// Exceeding this number results in the client getting an error when /// Exceeding this number results in the client getting an error when
/// attempting to connect. It should only affect servers under significant load. /// attempting to connect. It should only affect servers under significant
/// load.
/// ///
/// Generally set in the 64-2048 range. Default value is 2048. /// Generally set in the 64-2048 range. Default value is 2048.
/// ///
@ -121,9 +138,9 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
/// Set server host name. /// Set server host name.
/// ///
/// Host name is used by application router aa a hostname for url generation. /// Host name is used by application router aa a hostname for url
/// Check [ConnectionInfo](./dev/struct.ConnectionInfo.html#method.host) documentation /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo.
/// for more information. /// html#method.host) documentation for more information.
pub fn server_hostname(mut self, val: String) -> Self { pub fn server_hostname(mut self, val: String) -> Self {
self.host = Some(val); self.host = Some(val);
self self
@ -152,8 +169,9 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
/// Timeout for graceful workers shutdown. /// Timeout for graceful workers shutdown.
/// ///
/// After receiving a stop signal, workers have this much time to finish serving requests. /// After receiving a stop signal, workers have this much time to finish
/// Workers still alive after the timeout are force dropped. /// serving requests. Workers still alive after the timeout are force
/// dropped.
/// ///
/// By default shutdown timeout sets to 30 seconds. /// By default shutdown timeout sets to 30 seconds.
pub fn shutdown_timeout(mut self, sec: u16) -> Self { pub fn shutdown_timeout(mut self, sec: u16) -> Self {
@ -192,7 +210,7 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
Ok(lst) => { Ok(lst) => {
succ = true; succ = true;
self.sockets.push((lst.local_addr().unwrap(), lst)); self.sockets.push((lst.local_addr().unwrap(), lst));
}, }
Err(e) => err = Some(e), Err(e) => err = Some(e),
} }
} }
@ -201,16 +219,19 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
if let Some(e) = err.take() { if let Some(e) = err.take() {
Err(e) Err(e)
} else { } else {
Err(io::Error::new(io::ErrorKind::Other, "Can not bind to address.")) Err(io::Error::new(
io::ErrorKind::Other,
"Can not bind to address.",
))
} }
} else { } else {
Ok(self) Ok(self)
} }
} }
fn start_workers(&mut self, settings: &ServerSettings, handler: &StreamHandlerType) fn start_workers(
-> Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)> &mut self, settings: &ServerSettings, handler: &StreamHandlerType
{ ) -> Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)> {
// start workers // start workers
let mut workers = Vec::new(); let mut workers = Vec::new();
for idx in 0..self.threads { for idx in 0..self.threads {
@ -223,7 +244,8 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
let addr = Arbiter::start(move |ctx: &mut Context<_>| { let addr = Arbiter::start(move |ctx: &mut Context<_>| {
let apps: Vec<_> = (*factory)() let apps: Vec<_> = (*factory)()
.into_iter() .into_iter()
.map(|h| h.into_handler(s.clone())).collect(); .map(|h| h.into_handler(s.clone()))
.collect();
ctx.add_message_stream(rx); ctx.add_message_stream(rx);
Worker::new(apps, h, ka) Worker::new(apps, h, ka)
}); });
@ -248,12 +270,12 @@ impl<H> HttpServer<H> where H: IntoHttpHandler + 'static
} }
} }
impl<H: IntoHttpHandler> HttpServer<H> impl<H: IntoHttpHandler> HttpServer<H> {
{
/// Start listening for incoming connections. /// Start listening for incoming connections.
/// ///
/// This method starts number of http handler workers in separate threads. /// This method starts number of http handler workers in separate threads.
/// For each address this method starts separate thread which does `accept()` in a loop. /// For each address this method starts separate thread which does
/// `accept()` in a loop.
/// ///
/// This methods panics if no socket addresses get bound. /// This methods panics if no socket addresses get bound.
/// ///
@ -277,8 +299,7 @@ impl<H: IntoHttpHandler> HttpServer<H>
/// let _ = sys.run(); // <- Run actix system, this method actually starts all async processes /// let _ = sys.run(); // <- Run actix system, this method actually starts all async processes
/// } /// }
/// ``` /// ```
pub fn start(mut self) -> Addr<Syn, Self> pub fn start(mut self) -> Addr<Syn, Self> {
{
if self.sockets.is_empty() { if self.sockets.is_empty() {
panic!("HttpServer::bind() has to be called before start()"); panic!("HttpServer::bind() has to be called before start()");
} else { } else {
@ -287,15 +308,22 @@ impl<H: IntoHttpHandler> HttpServer<H>
self.sockets.drain(..).collect(); self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(&settings, &StreamHandlerType::Normal); let workers = self.start_workers(&settings, &StreamHandlerType::Normal);
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal}; let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Normal,
};
// start acceptors threads // start acceptors threads
for (addr, sock) in addrs { for (addr, sock) in addrs {
info!("Starting server on http://{}", addr); info!("Starting server on http://{}", addr);
self.accept.push( self.accept.push(start_accept_thread(
start_accept_thread( sock,
sock, addr, self.backlog, addr,
tx.clone(), info.clone(), workers.clone())); self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
} }
// start http server actor // start http server actor
@ -304,16 +332,17 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx); ctx.add_stream(rx);
self self
}); });
signals.map(|signals| signals.do_send( signals.map(|signals| {
signal::Subscribe(addr.clone().recipient()))); signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
addr addr
} }
} }
/// Spawn new thread and start listening for incoming connections. /// Spawn new thread and start listening for incoming connections.
/// ///
/// This method spawns new thread and starts new actix system. Other than that it is /// This method spawns new thread and starts new actix system. Other than
/// similar to `start()` method. This method blocks. /// that it is similar to `start()` method. This method blocks.
/// ///
/// This methods panics if no socket addresses get bound. /// This methods panics if no socket addresses get bound.
/// ///
@ -344,28 +373,38 @@ impl<H: IntoHttpHandler> HttpServer<H>
} }
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
impl<H: IntoHttpHandler> HttpServer<H> impl<H: IntoHttpHandler> HttpServer<H> {
{
/// Start listening for incoming tls connections. /// Start listening for incoming tls connections.
pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result<Addr<Syn, Self>> { pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result<Addr<Syn, Self>> {
if self.sockets.is_empty() { if self.sockets.is_empty() {
Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound")) Err(io::Error::new(
io::ErrorKind::Other,
"No socket addresses are bound",
))
} else { } else {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect(); let addrs: Vec<(net::SocketAddr, net::TcpListener)> =
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers( let workers =
&settings, &StreamHandlerType::Tls(acceptor.clone())); self.start_workers(&settings, &StreamHandlerType::Tls(acceptor.clone()));
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Tls(acceptor)}; let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Tls(acceptor),
};
// start acceptors threads // start acceptors threads
for (addr, sock) in addrs { for (addr, sock) in addrs {
info!("Starting server on https://{}", addr); info!("Starting server on https://{}", addr);
self.accept.push( self.accept.push(start_accept_thread(
start_accept_thread( sock,
sock, addr, self.backlog, addr,
tx.clone(), info.clone(), workers.clone())); self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
} }
// start http server actor // start http server actor
@ -374,23 +413,27 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx); ctx.add_stream(rx);
self self
}); });
signals.map(|signals| signals.do_send( signals.map(|signals| {
signal::Subscribe(addr.clone().recipient()))); signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
Ok(addr) Ok(addr)
} }
} }
} }
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
impl<H: IntoHttpHandler> HttpServer<H> impl<H: IntoHttpHandler> HttpServer<H> {
{
/// Start listening for incoming tls connections. /// Start listening for incoming tls connections.
/// ///
/// This method sets alpn protocols to "h2" and "http/1.1" /// This method sets alpn protocols to "h2" and "http/1.1"
pub fn start_ssl(mut self, mut builder: SslAcceptorBuilder) -> io::Result<Addr<Syn, Self>> pub fn start_ssl(
{ mut self, mut builder: SslAcceptorBuilder
) -> io::Result<Addr<Syn, Self>> {
if self.sockets.is_empty() { if self.sockets.is_empty() {
Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound")) Err(io::Error::new(
io::ErrorKind::Other,
"No socket addresses are bound",
))
} else { } else {
// alpn support // alpn support
if !self.no_http2 { if !self.no_http2 {
@ -407,19 +450,29 @@ impl<H: IntoHttpHandler> HttpServer<H>
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
let acceptor = builder.build(); let acceptor = builder.build();
let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect(); let addrs: Vec<(net::SocketAddr, net::TcpListener)> =
self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers( let workers = self.start_workers(
&settings, &StreamHandlerType::Alpn(acceptor.clone())); &settings,
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Alpn(acceptor)}; &StreamHandlerType::Alpn(acceptor.clone()),
);
let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Alpn(acceptor),
};
// start acceptors threads // start acceptors threads
for (addr, sock) in addrs { for (addr, sock) in addrs {
info!("Starting server on https://{}", addr); info!("Starting server on https://{}", addr);
self.accept.push( self.accept.push(start_accept_thread(
start_accept_thread( sock,
sock, addr, self.backlog, addr,
tx.clone(), info.clone(), workers.clone())); self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
} }
// start http server actor // start http server actor
@ -428,22 +481,23 @@ impl<H: IntoHttpHandler> HttpServer<H>
ctx.add_stream(rx); ctx.add_stream(rx);
self self
}); });
signals.map(|signals| signals.do_send( signals.map(|signals| {
signal::Subscribe(addr.clone().recipient()))); signals.do_send(signal::Subscribe(addr.clone().recipient()))
});
Ok(addr) Ok(addr)
} }
} }
} }
impl<H: IntoHttpHandler> HttpServer<H> impl<H: IntoHttpHandler> HttpServer<H> {
{
/// Start listening for incoming connections from a stream. /// Start listening for incoming connections from a stream.
/// ///
/// This method uses only one thread for handling incoming connections. /// This method uses only one thread for handling incoming connections.
pub fn start_incoming<T, A, S>(mut self, stream: S, secure: bool) -> Addr<Syn, Self> pub fn start_incoming<T, A, S>(mut self, stream: S, secure: bool) -> Addr<Syn, Self>
where S: Stream<Item=(T, A), Error=io::Error> + 'static, where
T: AsyncRead + AsyncWrite + 'static, S: Stream<Item = (T, A), Error = io::Error> + 'static,
A: 'static T: AsyncRead + AsyncWrite + 'static,
A: 'static,
{ {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
@ -452,15 +506,22 @@ impl<H: IntoHttpHandler> HttpServer<H>
self.sockets.drain(..).collect(); self.sockets.drain(..).collect();
let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false);
let workers = self.start_workers(&settings, &StreamHandlerType::Normal); let workers = self.start_workers(&settings, &StreamHandlerType::Normal);
let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal}; let info = Info {
addr: addrs[0].0,
handler: StreamHandlerType::Normal,
};
// start acceptors threads // start acceptors threads
for (addr, sock) in addrs { for (addr, sock) in addrs {
info!("Starting server on http://{}", addr); info!("Starting server on http://{}", addr);
self.accept.push( self.accept.push(start_accept_thread(
start_accept_thread( sock,
sock, addr, self.backlog, addr,
tx.clone(), info.clone(), workers.clone())); self.backlog,
tx.clone(),
info.clone(),
workers.clone(),
));
} }
} }
@ -468,21 +529,24 @@ impl<H: IntoHttpHandler> HttpServer<H>
let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap(); let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
let settings = ServerSettings::new(Some(addr), &self.host, secure); let settings = ServerSettings::new(Some(addr), &self.host, secure);
let apps: Vec<_> = (*self.factory)() let apps: Vec<_> = (*self.factory)()
.into_iter().map(|h| h.into_handler(settings.clone())).collect(); .into_iter()
.map(|h| h.into_handler(settings.clone()))
.collect();
self.h = Some(Rc::new(WorkerSettings::new(apps, self.keep_alive))); self.h = Some(Rc::new(WorkerSettings::new(apps, self.keep_alive)));
// start server // start server
let signals = self.subscribe_to_signals(); let signals = self.subscribe_to_signals();
let addr: Addr<Syn, _> = HttpServer::create(move |ctx| { let addr: Addr<Syn, _> = HttpServer::create(move |ctx| {
ctx.add_stream(rx); ctx.add_stream(rx);
ctx.add_message_stream( ctx.add_message_stream(stream.map_err(|_| ()).map(move |(t, _)| Conn {
stream io: WrapperStream::new(t),
.map_err(|_| ()) peer: None,
.map(move |(t, _)| Conn{io: WrapperStream::new(t), peer: None, http2: false})); http2: false,
}));
self self
}); });
signals.map(|signals| signals.do_send( signals
signal::Subscribe(addr.clone().recipient()))); .map(|signals| signals.do_send(signal::Subscribe(addr.clone().recipient())));
addr addr
} }
} }
@ -490,8 +554,7 @@ impl<H: IntoHttpHandler> HttpServer<H>
/// Signals support /// Signals support
/// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)` /// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)`
/// message to `System` actor. /// message to `System` actor.
impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H> impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H> {
{
type Result = (); type Result = ();
fn handle(&mut self, msg: signal::Signal, ctx: &mut Context<Self>) { fn handle(&mut self, msg: signal::Signal, ctx: &mut Context<Self>) {
@ -499,17 +562,17 @@ impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H>
signal::SignalType::Int => { signal::SignalType::Int => {
info!("SIGINT received, exiting"); info!("SIGINT received, exiting");
self.exit = true; self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: false}, ctx); Handler::<StopServer>::handle(self, StopServer { graceful: false }, ctx);
} }
signal::SignalType::Term => { signal::SignalType::Term => {
info!("SIGTERM received, stopping"); info!("SIGTERM received, stopping");
self.exit = true; self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: true}, ctx); Handler::<StopServer>::handle(self, StopServer { graceful: true }, ctx);
} }
signal::SignalType::Quit => { signal::SignalType::Quit => {
info!("SIGQUIT received, exiting"); info!("SIGQUIT received, exiting");
self.exit = true; self.exit = true;
Handler::<StopServer>::handle(self, StopServer{graceful: false}, ctx); Handler::<StopServer>::handle(self, StopServer { graceful: false }, ctx);
} }
_ => (), _ => (),
} }
@ -517,8 +580,7 @@ impl<H: IntoHttpHandler> Handler<signal::Signal> for HttpServer<H>
} }
/// Commands from accept threads /// Commands from accept threads
impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H> impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H> {
{
fn finished(&mut self, _: &mut Context<Self>) {} fn finished(&mut self, _: &mut Context<Self>) {}
fn handle(&mut self, msg: ServerCommand, _: &mut Context<Self>) { fn handle(&mut self, msg: ServerCommand, _: &mut Context<Self>) {
match msg { match msg {
@ -528,7 +590,7 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
if self.workers[i].0 == idx { if self.workers[i].0 == idx {
self.workers.swap_remove(i); self.workers.swap_remove(i);
found = true; found = true;
break break;
} }
} }
@ -541,21 +603,23 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
for i in 0..self.workers.len() { for i in 0..self.workers.len() {
if self.workers[i].0 == new_idx { if self.workers[i].0 == new_idx {
new_idx += 1; new_idx += 1;
continue 'found continue 'found;
} }
} }
break break;
} }
let h = info.handler; let h = info.handler;
let ka = self.keep_alive; let ka = self.keep_alive;
let factory = Arc::clone(&self.factory); let factory = Arc::clone(&self.factory);
let settings = ServerSettings::new(Some(info.addr), &self.host, false); let settings =
ServerSettings::new(Some(info.addr), &self.host, false);
let addr = Arbiter::start(move |ctx: &mut Context<_>| { let addr = Arbiter::start(move |ctx: &mut Context<_>| {
let apps: Vec<_> = (*factory)() let apps: Vec<_> = (*factory)()
.into_iter() .into_iter()
.map(|h| h.into_handler(settings.clone())).collect(); .map(|h| h.into_handler(settings.clone()))
.collect();
ctx.add_message_stream(rx); ctx.add_message_stream(rx);
Worker::new(apps, h, ka) Worker::new(apps, h, ka)
}); });
@ -566,30 +630,32 @@ impl<H: IntoHttpHandler> StreamHandler<ServerCommand, ()> for HttpServer<H>
self.workers.push((new_idx, addr)); self.workers.push((new_idx, addr));
} }
}, }
} }
} }
} }
impl<T, H> Handler<Conn<T>> for HttpServer<H> impl<T, H> Handler<Conn<T>> for HttpServer<H>
where T: IoStream, where
H: IntoHttpHandler, T: IoStream,
H: IntoHttpHandler,
{ {
type Result = (); type Result = ();
fn handle(&mut self, msg: Conn<T>, _: &mut Context<Self>) -> Self::Result { fn handle(&mut self, msg: Conn<T>, _: &mut Context<Self>) -> Self::Result {
Arbiter::handle().spawn( Arbiter::handle().spawn(HttpChannel::new(
HttpChannel::new( Rc::clone(self.h.as_ref().unwrap()),
Rc::clone(self.h.as_ref().unwrap()), msg.io, msg.peer, msg.http2)); msg.io,
msg.peer,
msg.http2,
));
} }
} }
impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H> impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H> {
{
type Result = (); type Result = ();
fn handle(&mut self, _: PauseServer, _: &mut Context<Self>) fn handle(&mut self, _: PauseServer, _: &mut Context<Self>) {
{
for item in &self.accept { for item in &self.accept {
let _ = item.1.send(Command::Pause); let _ = item.1.send(Command::Pause);
let _ = item.0.set_readiness(mio::Ready::readable()); let _ = item.0.set_readiness(mio::Ready::readable());
@ -597,8 +663,7 @@ impl<H: IntoHttpHandler> Handler<PauseServer> for HttpServer<H>
} }
} }
impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H> impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H> {
{
type Result = (); type Result = ();
fn handle(&mut self, _: ResumeServer, _: &mut Context<Self>) { fn handle(&mut self, _: ResumeServer, _: &mut Context<Self>) {
@ -609,8 +674,7 @@ impl<H: IntoHttpHandler> Handler<ResumeServer> for HttpServer<H>
} }
} }
impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H> impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H> {
{
type Result = actix::Response<(), ()>; type Result = actix::Response<(), ()>;
fn handle(&mut self, msg: StopServer, ctx: &mut Context<Self>) -> Self::Result { fn handle(&mut self, msg: StopServer, ctx: &mut Context<Self>) -> Self::Result {
@ -630,7 +694,9 @@ impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H>
}; };
for worker in &self.workers { for worker in &self.workers {
let tx2 = tx.clone(); let tx2 = tx.clone();
worker.1.send(StopWorker{graceful: dur}) worker
.1
.send(StopWorker { graceful: dur })
.into_actor(self) .into_actor(self)
.then(move |_, slf, ctx| { .then(move |_, slf, ctx| {
slf.workers.pop(); slf.workers.pop();
@ -645,12 +711,12 @@ impl<H: IntoHttpHandler> Handler<StopServer> for HttpServer<H>
} }
} }
actix::fut::ok(()) actix::fut::ok(())
}).spawn(ctx); })
.spawn(ctx);
} }
if !self.workers.is_empty() { if !self.workers.is_empty() {
Response::async( Response::async(rx.into_future().map(|_| ()).map_err(|_| ()))
rx.into_future().map(|_| ()).map_err(|_| ()))
} else { } else {
// we need to stop system if server was spawned // we need to stop system if server was spawned
if self.exit { if self.exit {
@ -673,156 +739,184 @@ enum Command {
fn start_accept_thread( fn start_accept_thread(
sock: net::TcpListener, addr: net::SocketAddr, backlog: i32, sock: net::TcpListener, addr: net::SocketAddr, backlog: i32,
srv: mpsc::UnboundedSender<ServerCommand>, info: Info, srv: mpsc::UnboundedSender<ServerCommand>, info: Info,
mut workers: Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)>) mut workers: Vec<(usize, mpsc::UnboundedSender<Conn<net::TcpStream>>)>,
-> (mio::SetReadiness, sync_mpsc::Sender<Command>) ) -> (mio::SetReadiness, sync_mpsc::Sender<Command>) {
{
let (tx, rx) = sync_mpsc::channel(); let (tx, rx) = sync_mpsc::channel();
let (reg, readiness) = mio::Registration::new2(); let (reg, readiness) = mio::Registration::new2();
// start accept thread // start accept thread
#[cfg_attr(feature="cargo-clippy", allow(cyclomatic_complexity))] #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))]
let _ = thread::Builder::new().name(format!("Accept on {}", addr)).spawn(move || { let _ = thread::Builder::new()
const SRV: mio::Token = mio::Token(0); .name(format!("Accept on {}", addr))
const CMD: mio::Token = mio::Token(1); .spawn(move || {
const SRV: mio::Token = mio::Token(0);
const CMD: mio::Token = mio::Token(1);
let mut server = Some( let mut server = Some(
mio::net::TcpListener::from_std(sock) mio::net::TcpListener::from_std(sock)
.expect("Can not create mio::net::TcpListener")); .expect("Can not create mio::net::TcpListener"),
);
// Create a poll instance // Create a poll instance
let poll = match mio::Poll::new() { let poll = match mio::Poll::new() {
Ok(poll) => poll, Ok(poll) => poll,
Err(err) => panic!("Can not create mio::Poll: {}", err), Err(err) => panic!("Can not create mio::Poll: {}", err),
}; };
// Start listening for incoming connections // Start listening for incoming connections
if let Some(ref srv) = server { if let Some(ref srv) = server {
if let Err(err) = poll.register( if let Err(err) =
srv, SRV, mio::Ready::readable(), mio::PollOpt::edge()) { poll.register(srv, SRV, mio::Ready::readable(), mio::PollOpt::edge())
panic!("Can not register io: {}", err); {
} panic!("Can not register io: {}", err);
}
// Start listening for incoming commands
if let Err(err) = poll.register(&reg, CMD,
mio::Ready::readable(), mio::PollOpt::edge()) {
panic!("Can not register Registration: {}", err);
}
// Create storage for events
let mut events = mio::Events::with_capacity(128);
// Sleep on error
let sleep = Duration::from_millis(100);
let mut next = 0;
loop {
if let Err(err) = poll.poll(&mut events, None) {
panic!("Poll error: {}", err);
}
for event in events.iter() {
match event.token() {
SRV => if let Some(ref server) = server {
loop {
match server.accept_std() {
Ok((sock, addr)) => {
let mut msg = Conn{
io: sock, peer: Some(addr), http2: false};
while !workers.is_empty() {
match workers[next].1.unbounded_send(msg) {
Ok(_) => (),
Err(err) => {
let _ = srv.unbounded_send(
ServerCommand::WorkerDied(
workers[next].0, info.clone()));
msg = err.into_inner();
workers.swap_remove(next);
if workers.is_empty() {
error!("No workers");
thread::sleep(sleep);
break
} else if workers.len() <= next {
next = 0;
}
continue
}
}
next = (next + 1) % workers.len();
break
}
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock =>
break,
Err(ref e) if connection_error(e) =>
continue,
Err(e) => {
error!("Error accepting connection: {}", e);
// sleep after error
thread::sleep(sleep);
break
}
}
}
},
CMD => match rx.try_recv() {
Ok(cmd) => match cmd {
Command::Pause => if let Some(server) = server.take() {
if let Err(err) = poll.deregister(&server) {
error!("Can not deregister server socket {}", err);
} else {
info!("Paused accepting connections on {}", addr);
}
},
Command::Resume => {
let lst = create_tcp_listener(addr, backlog)
.expect("Can not create net::TcpListener");
server = Some(
mio::net::TcpListener::from_std(lst)
.expect("Can not create mio::net::TcpListener"));
if let Some(ref server) = server {
if let Err(err) = poll.register(
server, SRV, mio::Ready::readable(), mio::PollOpt::edge())
{
error!("Can not resume socket accept process: {}", err);
} else {
info!("Accepting connections on {} has been resumed",
addr);
}
}
},
Command::Stop => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return
},
Command::Worker(idx, addr) => {
workers.push((idx, addr));
},
},
Err(err) => match err {
sync_mpsc::TryRecvError::Empty => (),
sync_mpsc::TryRecvError::Disconnected => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return
},
}
},
_ => unreachable!(),
} }
} }
}
}); // Start listening for incoming commands
if let Err(err) = poll.register(
&reg,
CMD,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
panic!("Can not register Registration: {}", err);
}
// Create storage for events
let mut events = mio::Events::with_capacity(128);
// Sleep on error
let sleep = Duration::from_millis(100);
let mut next = 0;
loop {
if let Err(err) = poll.poll(&mut events, None) {
panic!("Poll error: {}", err);
}
for event in events.iter() {
match event.token() {
SRV => if let Some(ref server) = server {
loop {
match server.accept_std() {
Ok((sock, addr)) => {
let mut msg = Conn {
io: sock,
peer: Some(addr),
http2: false,
};
while !workers.is_empty() {
match workers[next].1.unbounded_send(msg) {
Ok(_) => (),
Err(err) => {
let _ = srv.unbounded_send(
ServerCommand::WorkerDied(
workers[next].0,
info.clone(),
),
);
msg = err.into_inner();
workers.swap_remove(next);
if workers.is_empty() {
error!("No workers");
thread::sleep(sleep);
break;
} else if workers.len() <= next {
next = 0;
}
continue;
}
}
next = (next + 1) % workers.len();
break;
}
}
Err(ref e)
if e.kind() == io::ErrorKind::WouldBlock =>
{
break
}
Err(ref e) if connection_error(e) => continue,
Err(e) => {
error!("Error accepting connection: {}", e);
// sleep after error
thread::sleep(sleep);
break;
}
}
}
},
CMD => match rx.try_recv() {
Ok(cmd) => match cmd {
Command::Pause => if let Some(server) = server.take() {
if let Err(err) = poll.deregister(&server) {
error!(
"Can not deregister server socket {}",
err
);
} else {
info!(
"Paused accepting connections on {}",
addr
);
}
},
Command::Resume => {
let lst = create_tcp_listener(addr, backlog)
.expect("Can not create net::TcpListener");
server = Some(
mio::net::TcpListener::from_std(lst).expect(
"Can not create mio::net::TcpListener",
),
);
if let Some(ref server) = server {
if let Err(err) = poll.register(
server,
SRV,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
error!("Can not resume socket accept process: {}", err);
} else {
info!("Accepting connections on {} has been resumed",
addr);
}
}
}
Command::Stop => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return;
}
Command::Worker(idx, addr) => {
workers.push((idx, addr));
}
},
Err(err) => match err {
sync_mpsc::TryRecvError::Empty => (),
sync_mpsc::TryRecvError::Disconnected => {
if let Some(server) = server.take() {
let _ = poll.deregister(&server);
}
return;
}
},
},
_ => unreachable!(),
}
}
}
});
(readiness, tx) (readiness, tx)
} }
fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::TcpListener> { fn create_tcp_listener(
addr: net::SocketAddr, backlog: i32
) -> io::Result<net::TcpListener> {
let builder = match addr { let builder = match addr {
net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, net::SocketAddr::V4(_) => TcpBuilder::new_v4()?,
net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, net::SocketAddr::V6(_) => TcpBuilder::new_v6()?,
@ -840,7 +934,7 @@ fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::T
/// The timeout is useful to handle resource exhaustion errors like ENFILE /// The timeout is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE. Otherwise, could enter into tight loop. /// and EMFILE. Otherwise, could enter into tight loop.
fn connection_error(e: &io::Error) -> bool { fn connection_error(e: &io::Error) -> bool {
e.kind() == io::ErrorKind::ConnectionRefused || e.kind() == io::ErrorKind::ConnectionRefused
e.kind() == io::ErrorKind::ConnectionAborted || || e.kind() == io::ErrorKind::ConnectionAborted
e.kind() == io::ErrorKind::ConnectionReset || e.kind() == io::ErrorKind::ConnectionReset
} }

View File

@ -1,14 +1,15 @@
use std::io; use bytes::{BufMut, BytesMut};
use bytes::{BytesMut, BufMut};
use futures::{Async, Poll}; use futures::{Async, Poll};
use std::io;
use super::IoStream; use super::IoStream;
const LW_BUFFER_SIZE: usize = 4096; const LW_BUFFER_SIZE: usize = 4096;
const HW_BUFFER_SIZE: usize = 32_768; const HW_BUFFER_SIZE: usize = 32_768;
pub fn read_from_io<T: IoStream>(
pub fn read_from_io<T: IoStream>(io: &mut T, buf: &mut BytesMut) -> Poll<usize, io::Error> { io: &mut T, buf: &mut BytesMut
) -> Poll<usize, io::Error> {
unsafe { unsafe {
if buf.remaining_mut() < LW_BUFFER_SIZE { if buf.remaining_mut() < LW_BUFFER_SIZE {
buf.reserve(HW_BUFFER_SIZE); buf.reserve(HW_BUFFER_SIZE);
@ -17,7 +18,7 @@ pub fn read_from_io<T: IoStream>(io: &mut T, buf: &mut BytesMut) -> Poll<usize,
Ok(n) => { Ok(n) => {
buf.advance_mut(n); buf.advance_mut(n);
Ok(Async::Ready(n)) Ok(Async::Ready(n))
}, }
Err(e) => { Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock { if e.kind() == io::ErrorKind::WouldBlock {
Ok(Async::NotReady) Ok(Async::NotReady)

View File

@ -1,31 +1,30 @@
use std::{net, time};
use std::rc::Rc;
use futures::Future; use futures::Future;
use futures::unsync::oneshot; use futures::unsync::oneshot;
use net2::TcpStreamExt;
use std::rc::Rc;
use std::{net, time};
use tokio_core::net::TcpStream; use tokio_core::net::TcpStream;
use tokio_core::reactor::Handle; use tokio_core::reactor::Handle;
use net2::TcpStreamExt;
#[cfg(any(feature="tls", feature="alpn"))] #[cfg(any(feature = "tls", feature = "alpn"))]
use futures::future; use futures::future;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
use native_tls::TlsAcceptor; use native_tls::TlsAcceptor;
#[cfg(feature="tls")] #[cfg(feature = "tls")]
use tokio_tls::TlsAcceptorExt; use tokio_tls::TlsAcceptorExt;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use openssl::ssl::SslAcceptor; use openssl::ssl::SslAcceptor;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use tokio_openssl::SslAcceptorExt; use tokio_openssl::SslAcceptorExt;
use actix::*;
use actix::msgs::StopArbiter; use actix::msgs::StopArbiter;
use actix::*;
use server::{HttpHandler, KeepAlive};
use server::channel::HttpChannel; use server::channel::HttpChannel;
use server::settings::WorkerSettings; use server::settings::WorkerSettings;
use server::{HttpHandler, KeepAlive};
#[derive(Message)] #[derive(Message)]
pub(crate) struct Conn<T> { pub(crate) struct Conn<T> {
@ -46,9 +45,12 @@ impl Message for StopWorker {
/// Http worker /// Http worker
/// ///
/// Worker accepts Socket objects via unbounded channel and start requests processing. /// Worker accepts Socket objects via unbounded channel and start requests
pub(crate) /// processing.
struct Worker<H> where H: HttpHandler + 'static { pub(crate) struct Worker<H>
where
H: HttpHandler + 'static,
{
settings: Rc<WorkerSettings<H>>, settings: Rc<WorkerSettings<H>>,
hnd: Handle, hnd: Handle,
handler: StreamHandlerType, handler: StreamHandlerType,
@ -56,10 +58,9 @@ struct Worker<H> where H: HttpHandler + 'static {
} }
impl<H: HttpHandler + 'static> Worker<H> { impl<H: HttpHandler + 'static> Worker<H> {
pub(crate) fn new(
pub(crate) fn new(h: Vec<H>, handler: StreamHandlerType, keep_alive: KeepAlive) h: Vec<H>, handler: StreamHandlerType, keep_alive: KeepAlive
-> Worker<H> ) -> Worker<H> {
{
let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive {
Some(time::Duration::new(val as u64, 0)) Some(time::Duration::new(val as u64, 0))
} else { } else {
@ -76,11 +77,14 @@ impl<H: HttpHandler + 'static> Worker<H> {
fn update_time(&self, ctx: &mut Context<Self>) { fn update_time(&self, ctx: &mut Context<Self>) {
self.settings.update_date(); self.settings.update_date();
ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_time(ctx)); ctx.run_later(time::Duration::new(1, 0), |slf, ctx| {
slf.update_time(ctx)
});
} }
fn shutdown_timeout(&self, ctx: &mut Context<Self>, fn shutdown_timeout(
tx: oneshot::Sender<bool>, dur: time::Duration) { &self, ctx: &mut Context<Self>, tx: oneshot::Sender<bool>, dur: time::Duration
) {
// sleep for 1 second and then check again // sleep for 1 second and then check again
ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| { ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| {
let num = slf.settings.num_channels(); let num = slf.settings.num_channels();
@ -99,7 +103,10 @@ impl<H: HttpHandler + 'static> Worker<H> {
} }
} }
impl<H: 'static> Actor for Worker<H> where H: HttpHandler + 'static { impl<H: 'static> Actor for Worker<H>
where
H: HttpHandler + 'static,
{
type Context = Context<Self>; type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) { fn started(&mut self, ctx: &mut Self::Context) {
@ -108,22 +115,24 @@ impl<H: 'static> Actor for Worker<H> where H: HttpHandler + 'static {
} }
impl<H> Handler<Conn<net::TcpStream>> for Worker<H> impl<H> Handler<Conn<net::TcpStream>> for Worker<H>
where H: HttpHandler + 'static, where
H: HttpHandler + 'static,
{ {
type Result = (); type Result = ();
fn handle(&mut self, msg: Conn<net::TcpStream>, _: &mut Context<Self>) fn handle(&mut self, msg: Conn<net::TcpStream>, _: &mut Context<Self>) {
{
if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() {
error!("Can not set socket keep-alive option"); error!("Can not set socket keep-alive option");
} }
self.handler.handle(Rc::clone(&self.settings), &self.hnd, msg); self.handler
.handle(Rc::clone(&self.settings), &self.hnd, msg);
} }
} }
/// `StopWorker` message handler /// `StopWorker` message handler
impl<H> Handler<StopWorker> for Worker<H> impl<H> Handler<StopWorker> for Worker<H>
where H: HttpHandler + 'static, where
H: HttpHandler + 'static,
{ {
type Result = Response<bool, ()>; type Result = Response<bool, ()>;
@ -148,17 +157,16 @@ impl<H> Handler<StopWorker> for Worker<H>
#[derive(Clone)] #[derive(Clone)]
pub(crate) enum StreamHandlerType { pub(crate) enum StreamHandlerType {
Normal, Normal,
#[cfg(feature="tls")] #[cfg(feature = "tls")]
Tls(TlsAcceptor), Tls(TlsAcceptor),
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
Alpn(SslAcceptor), Alpn(SslAcceptor),
} }
impl StreamHandlerType { impl StreamHandlerType {
fn handle<H: HttpHandler>(
fn handle<H: HttpHandler>(&mut self, &mut self, h: Rc<WorkerSettings<H>>, hnd: &Handle, msg: Conn<net::TcpStream>
h: Rc<WorkerSettings<H>>, ) {
hnd: &Handle, msg: Conn<net::TcpStream>) {
match *self { match *self {
StreamHandlerType::Normal => { StreamHandlerType::Normal => {
let _ = msg.io.set_nodelay(true); let _ = msg.io.set_nodelay(true);
@ -167,7 +175,7 @@ impl StreamHandlerType {
hnd.spawn(HttpChannel::new(h, io, msg.peer, msg.http2)); hnd.spawn(HttpChannel::new(h, io, msg.peer, msg.http2));
} }
#[cfg(feature="tls")] #[cfg(feature = "tls")]
StreamHandlerType::Tls(ref acceptor) => { StreamHandlerType::Tls(ref acceptor) => {
let Conn { io, peer, http2 } = msg; let Conn { io, peer, http2 } = msg;
let _ = io.set_nodelay(true); let _ = io.set_nodelay(true);
@ -177,16 +185,21 @@ impl StreamHandlerType {
hnd.spawn( hnd.spawn(
TlsAcceptorExt::accept_async(acceptor, io).then(move |res| { TlsAcceptorExt::accept_async(acceptor, io).then(move |res| {
match res { match res {
Ok(io) => Arbiter::handle().spawn( Ok(io) => Arbiter::handle().spawn(HttpChannel::new(
HttpChannel::new(h, io, peer, http2)), h,
Err(err) => io,
trace!("Error during handling tls connection: {}", err), peer,
http2,
)),
Err(err) => {
trace!("Error during handling tls connection: {}", err)
}
}; };
future::result(Ok(())) future::result(Ok(()))
}) }),
); );
} }
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
StreamHandlerType::Alpn(ref acceptor) => { StreamHandlerType::Alpn(ref acceptor) => {
let Conn { io, peer, .. } = msg; let Conn { io, peer, .. } = msg;
let _ = io.set_nodelay(true); let _ = io.set_nodelay(true);
@ -197,20 +210,26 @@ impl StreamHandlerType {
SslAcceptorExt::accept_async(acceptor, io).then(move |res| { SslAcceptorExt::accept_async(acceptor, io).then(move |res| {
match res { match res {
Ok(io) => { Ok(io) => {
let http2 = if let Some(p) = io.get_ref().ssl().selected_alpn_protocol() let http2 = if let Some(p) =
io.get_ref().ssl().selected_alpn_protocol()
{ {
p.len() == 2 && &p == b"h2" p.len() == 2 && &p == b"h2"
} else { } else {
false false
}; };
Arbiter::handle().spawn( Arbiter::handle().spawn(HttpChannel::new(
HttpChannel::new(h, io, peer, http2)); h,
}, io,
Err(err) => peer,
trace!("Error during handling tls connection: {}", err), http2,
));
}
Err(err) => {
trace!("Error during handling tls connection: {}", err)
}
}; };
future::result(Ok(())) future::result(Ok(()))
}) }),
); );
} }
} }

View File

@ -1,37 +1,37 @@
//! Various helpers for Actix applications to use during testing. //! Various helpers for Actix applications to use during testing.
use std::{net, thread};
use std::rc::Rc; use std::rc::Rc;
use std::sync::mpsc;
use std::str::FromStr; use std::str::FromStr;
use std::sync::mpsc;
use std::{net, thread};
use actix::{Actor, Arbiter, Addr, Syn, System, SystemRunner, Unsync, msgs}; use actix::{msgs, Actor, Addr, Arbiter, Syn, System, SystemRunner, Unsync};
use cookie::Cookie; use cookie::Cookie;
use http::{Uri, Method, Version, HeaderMap, HttpTryFrom};
use http::header::HeaderName;
use futures::Future; use futures::Future;
use http::header::HeaderName;
use http::{HeaderMap, HttpTryFrom, Method, Uri, Version};
use net2::TcpBuilder;
use tokio_core::net::TcpListener; use tokio_core::net::TcpListener;
use tokio_core::reactor::Core; use tokio_core::reactor::Core;
use net2::TcpBuilder;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
use openssl::ssl::SslAcceptor; use openssl::ssl::SslAcceptor;
use ws;
use body::Binary;
use error::Error;
use header::{Header, IntoHeaderValue};
use handler::{Handler, Responder, ReplyItem};
use middleware::Middleware;
use application::{App, HttpApplication}; use application::{App, HttpApplication};
use param::Params; use body::Binary;
use router::Router; use client::{ClientConnector, ClientRequest, ClientRequestBuilder};
use payload::Payload; use error::Error;
use resource::ResourceHandler; use handler::{Handler, ReplyItem, Responder};
use header::{Header, IntoHeaderValue};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
use middleware::Middleware;
use param::Params;
use payload::Payload;
use resource::ResourceHandler;
use router::Router;
use server::{HttpServer, IntoHttpHandler, ServerSettings}; use server::{HttpServer, IntoHttpHandler, ServerSettings};
use client::{ClientRequest, ClientRequestBuilder, ClientConnector}; use ws;
/// The `TestServer` type. /// The `TestServer` type.
/// ///
@ -69,20 +69,20 @@ pub struct TestServer {
} }
impl TestServer { impl TestServer {
/// Start new test server /// Start new test server
/// ///
/// This method accepts configuration method. You can add /// This method accepts configuration method. You can add
/// middlewares or set handlers for test application. /// middlewares or set handlers for test application.
pub fn new<F>(config: F) -> Self pub fn new<F>(config: F) -> Self
where F: Sync + Send + 'static + Fn(&mut TestApp<()>) where
F: Sync + Send + 'static + Fn(&mut TestApp<()>),
{ {
TestServerBuilder::new(||()).start(config) TestServerBuilder::new(|| ()).start(config)
} }
/// Create test server builder /// Create test server builder
pub fn build() -> TestServerBuilder<()> { pub fn build() -> TestServerBuilder<()> {
TestServerBuilder::new(||()) TestServerBuilder::new(|| ())
} }
/// Create test server builder with specific state factory /// Create test server builder with specific state factory
@ -91,17 +91,19 @@ impl TestServer {
/// Also it can be used for external dependecy initialization, /// Also it can be used for external dependecy initialization,
/// like creating sync actors for diesel integration. /// like creating sync actors for diesel integration.
pub fn build_with_state<F, S>(state: F) -> TestServerBuilder<S> pub fn build_with_state<F, S>(state: F) -> TestServerBuilder<S>
where F: Fn() -> S + Sync + Send + 'static, where
S: 'static, F: Fn() -> S + Sync + Send + 'static,
S: 'static,
{ {
TestServerBuilder::new(state) TestServerBuilder::new(state)
} }
/// Start new test server with application factory /// Start new test server with application factory
pub fn with_factory<F, U, H>(factory: F) -> Self pub fn with_factory<F, U, H>(factory: F) -> Self
where F: Fn() -> U + Sync + Send + 'static, where
U: IntoIterator<Item=H> + 'static, F: Fn() -> U + Sync + Send + 'static,
H: IntoHttpHandler + 'static, U: IntoIterator<Item = H> + 'static,
H: IntoHttpHandler + 'static,
{ {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
@ -110,8 +112,8 @@ impl TestServer {
let sys = System::new("actix-test-server"); let sys = System::new("actix-test-server");
let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = tcp.local_addr().unwrap(); let local_addr = tcp.local_addr().unwrap();
let tcp = TcpListener::from_listener( let tcp =
tcp, &local_addr, Arbiter::handle()).unwrap(); TcpListener::from_listener(tcp, &local_addr, Arbiter::handle()).unwrap();
HttpServer::new(factory) HttpServer::new(factory)
.disable_signals() .disable_signals()
@ -134,15 +136,15 @@ impl TestServer {
} }
fn get_conn() -> Addr<Unsync, ClientConnector> { fn get_conn() -> Addr<Unsync, ClientConnector> {
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
{ {
use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode}; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE); builder.set_verify(SslVerifyMode::NONE);
ClientConnector::with_connector(builder.build()).start() ClientConnector::with_connector(builder.build()).start()
} }
#[cfg(not(feature="alpn"))] #[cfg(not(feature = "alpn"))]
{ {
ClientConnector::default().start() ClientConnector::default().start()
} }
@ -166,9 +168,19 @@ impl TestServer {
/// Construct test server url /// Construct test server url
pub fn url(&self, uri: &str) -> String { pub fn url(&self, uri: &str) -> String {
if uri.starts_with('/') { if uri.starts_with('/') {
format!("{}://{}{}", if self.ssl {"https"} else {"http"}, self.addr, uri) format!(
"{}://{}{}",
if self.ssl { "https" } else { "http" },
self.addr,
uri
)
} else { } else {
format!("{}://{}/{}", if self.ssl {"https"} else {"http"}, self.addr, uri) format!(
"{}://{}/{}",
if self.ssl { "https" } else { "http" },
self.addr,
uri
)
} }
} }
@ -182,16 +194,20 @@ impl TestServer {
/// Execute future on current core /// Execute future on current core
pub fn execute<F, I, E>(&mut self, fut: F) -> Result<I, E> pub fn execute<F, I, E>(&mut self, fut: F) -> Result<I, E>
where F: Future<Item=I, Error=E> where
F: Future<Item = I, Error = E>,
{ {
self.system.run_until_complete(fut) self.system.run_until_complete(fut)
} }
/// Connect to websocket server /// Connect to websocket server
pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> { pub fn ws(
&mut self
) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> {
let url = self.url("/"); let url = self.url("/");
self.system.run_until_complete( self.system.run_until_complete(
ws::Client::with_connector(url, self.conn.clone()).connect()) ws::Client::with_connector(url, self.conn.clone()).connect(),
)
} }
/// Create `GET` request /// Create `GET` request
@ -231,23 +247,23 @@ impl Drop for TestServer {
/// builder-like pattern. /// builder-like pattern.
pub struct TestServerBuilder<S> { pub struct TestServerBuilder<S> {
state: Box<Fn() -> S + Sync + Send + 'static>, state: Box<Fn() -> S + Sync + Send + 'static>,
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
ssl: Option<SslAcceptor>, ssl: Option<SslAcceptor>,
} }
impl<S: 'static> TestServerBuilder<S> { impl<S: 'static> TestServerBuilder<S> {
pub fn new<F>(state: F) -> TestServerBuilder<S> pub fn new<F>(state: F) -> TestServerBuilder<S>
where F: Fn() -> S + Sync + Send + 'static where
F: Fn() -> S + Sync + Send + 'static,
{ {
TestServerBuilder { TestServerBuilder {
state: Box::new(state), state: Box::new(state),
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
ssl: None, ssl: None,
} }
} }
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
/// Create ssl server /// Create ssl server
pub fn ssl(mut self, ssl: SslAcceptor) -> Self { pub fn ssl(mut self, ssl: SslAcceptor) -> Self {
self.ssl = Some(ssl); self.ssl = Some(ssl);
@ -257,13 +273,14 @@ impl<S: 'static> TestServerBuilder<S> {
#[allow(unused_mut)] #[allow(unused_mut)]
/// Configure test application and run test server /// Configure test application and run test server
pub fn start<F>(mut self, config: F) -> TestServer pub fn start<F>(mut self, config: F) -> TestServer
where F: Sync + Send + 'static + Fn(&mut TestApp<S>), where
F: Sync + Send + 'static + Fn(&mut TestApp<S>),
{ {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
let ssl = self.ssl.is_some(); let ssl = self.ssl.is_some();
#[cfg(not(feature="alpn"))] #[cfg(not(feature = "alpn"))]
let ssl = false; let ssl = false;
// run server in separate thread // run server in separate thread
@ -272,38 +289,38 @@ impl<S: 'static> TestServerBuilder<S> {
let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = tcp.local_addr().unwrap(); let local_addr = tcp.local_addr().unwrap();
let tcp = TcpListener::from_listener( let tcp =
tcp, &local_addr, Arbiter::handle()).unwrap(); TcpListener::from_listener(tcp, &local_addr, Arbiter::handle()).unwrap();
let state = self.state; let state = self.state;
let srv = HttpServer::new(move || { let srv = HttpServer::new(move || {
let mut app = TestApp::new(state()); let mut app = TestApp::new(state());
config(&mut app); config(&mut app);
vec![app]}) vec![app]
.disable_signals(); }).disable_signals();
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
{ {
use std::io;
use futures::Stream; use futures::Stream;
use std::io;
use tokio_openssl::SslAcceptorExt; use tokio_openssl::SslAcceptorExt;
let ssl = self.ssl.take(); let ssl = self.ssl.take();
if let Some(ssl) = ssl { if let Some(ssl) = ssl {
srv.start_incoming( srv.start_incoming(
tcp.incoming() tcp.incoming().and_then(move |(sock, addr)| {
.and_then(move |(sock, addr)| { ssl.accept_async(sock)
ssl.accept_async(sock) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .map(move |s| (s, addr))
.map(move |s| (s, addr)) }),
}), false,
false); );
} else { } else {
srv.start_incoming(tcp.incoming(), false); srv.start_incoming(tcp.incoming(), false);
} }
} }
#[cfg(not(feature="alpn"))] #[cfg(not(feature = "alpn"))]
{ {
srv.start_incoming(tcp.incoming(), false); srv.start_incoming(tcp.incoming(), false);
} }
@ -326,24 +343,30 @@ impl<S: 'static> TestServerBuilder<S> {
} }
/// Test application helper for testing request handlers. /// Test application helper for testing request handlers.
pub struct TestApp<S=()> { pub struct TestApp<S = ()> {
app: Option<App<S>>, app: Option<App<S>>,
} }
impl<S: 'static> TestApp<S> { impl<S: 'static> TestApp<S> {
fn new(state: S) -> TestApp<S> { fn new(state: S) -> TestApp<S> {
let app = App::with_state(state); let app = App::with_state(state);
TestApp{app: Some(app)} TestApp { app: Some(app) }
} }
/// Register handler for "/" /// Register handler for "/"
pub fn handler<H: Handler<S>>(&mut self, handler: H) { pub fn handler<H: Handler<S>>(&mut self, handler: H) {
self.app = Some(self.app.take().unwrap().resource("/", |r| r.h(handler))); self.app = Some(
self.app
.take()
.unwrap()
.resource("/", |r| r.h(handler)),
);
} }
/// Register middleware /// Register middleware
pub fn middleware<T>(&mut self, mw: T) -> &mut TestApp<S> pub fn middleware<T>(&mut self, mw: T) -> &mut TestApp<S>
where T: Middleware<S> + 'static where
T: Middleware<S> + 'static,
{ {
self.app = Some(self.app.take().unwrap().middleware(mw)); self.app = Some(self.app.take().unwrap().middleware(mw));
self self
@ -352,7 +375,8 @@ impl<S: 'static> TestApp<S> {
/// Register resource. This method is similar /// Register resource. This method is similar
/// to `App::resource()` method. /// to `App::resource()` method.
pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut TestApp<S> pub fn resource<F, R>(&mut self, path: &str, f: F) -> &mut TestApp<S>
where F: FnOnce(&mut ResourceHandler<S>) -> R + 'static where
F: FnOnce(&mut ResourceHandler<S>) -> R + 'static,
{ {
self.app = Some(self.app.take().unwrap().resource(path, f)); self.app = Some(self.app.take().unwrap().resource(path, f));
self self
@ -419,7 +443,6 @@ pub struct TestRequest<S> {
} }
impl Default for TestRequest<()> { impl Default for TestRequest<()> {
fn default() -> TestRequest<()> { fn default() -> TestRequest<()> {
TestRequest { TestRequest {
state: (), state: (),
@ -435,28 +458,27 @@ impl Default for TestRequest<()> {
} }
impl TestRequest<()> { impl TestRequest<()> {
/// Create TestRequest and set request uri /// Create TestRequest and set request uri
pub fn with_uri(path: &str) -> TestRequest<()> { pub fn with_uri(path: &str) -> TestRequest<()> {
TestRequest::default().uri(path) TestRequest::default().uri(path)
} }
/// Create TestRequest and set header /// Create TestRequest and set header
pub fn with_hdr<H: Header>(hdr: H) -> TestRequest<()> pub fn with_hdr<H: Header>(hdr: H) -> TestRequest<()> {
{
TestRequest::default().set(hdr) TestRequest::default().set(hdr)
} }
/// Create TestRequest and set header /// Create TestRequest and set header
pub fn with_header<K, V>(key: K, value: V) -> TestRequest<()> pub fn with_header<K, V>(key: K, value: V) -> TestRequest<()>
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue, where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
TestRequest::default().header(key, value) TestRequest::default().header(key, value)
} }
} }
impl<S> TestRequest<S> { impl<S> TestRequest<S> {
/// Start HttpRequest build process with application state /// Start HttpRequest build process with application state
pub fn with_state(state: S) -> TestRequest<S> { pub fn with_state(state: S) -> TestRequest<S> {
TestRequest { TestRequest {
@ -490,23 +512,24 @@ impl<S> TestRequest<S> {
} }
/// Set a header /// Set a header
pub fn set<H: Header>(mut self, hdr: H) -> Self pub fn set<H: Header>(mut self, hdr: H) -> Self {
{
if let Ok(value) = hdr.try_into() { if let Ok(value) = hdr.try_into() {
self.headers.append(H::name(), value); self.headers.append(H::name(), value);
return self return self;
} }
panic!("Can not set header"); panic!("Can not set header");
} }
/// Set a header /// Set a header
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
if let Ok(key) = HeaderName::try_from(key) { if let Ok(key) = HeaderName::try_from(key) {
if let Ok(value) = value.try_into() { if let Ok(value) = value.try_into() {
self.headers.append(key, value); self.headers.append(key, value);
return self return self;
} }
} }
panic!("Can not create header"); panic!("Can not create header");
@ -529,7 +552,16 @@ impl<S> TestRequest<S> {
/// Complete request creation and generate `HttpRequest` instance /// Complete request creation and generate `HttpRequest` instance
pub fn finish(self) -> HttpRequest<S> { pub fn finish(self) -> HttpRequest<S> {
let TestRequest { state, method, uri, version, headers, params, cookies, payload } = self; let TestRequest {
state,
method,
uri,
version,
headers,
params,
cookies,
payload,
} = self;
let req = HttpRequest::new(method, uri, version, headers, payload); let req = HttpRequest::new(method, uri, version, headers, payload);
req.as_mut().cookies = cookies; req.as_mut().cookies = cookies;
req.as_mut().params = params; req.as_mut().params = params;
@ -540,8 +572,16 @@ impl<S> TestRequest<S> {
#[cfg(test)] #[cfg(test)]
/// Complete request creation and generate `HttpRequest` instance /// Complete request creation and generate `HttpRequest` instance
pub(crate) fn finish_with_router(self, router: Router) -> HttpRequest<S> { pub(crate) fn finish_with_router(self, router: Router) -> HttpRequest<S> {
let TestRequest { state, method, uri, let TestRequest {
version, headers, params, cookies, payload } = self; state,
method,
uri,
version,
headers,
params,
cookies,
payload,
} = self;
let req = HttpRequest::new(method, uri, version, headers, payload); let req = HttpRequest::new(method, uri, version, headers, payload);
req.as_mut().cookies = cookies; req.as_mut().cookies = cookies;
@ -553,18 +593,16 @@ impl<S> TestRequest<S> {
/// with generated request. /// with generated request.
/// ///
/// This method panics is handler returns actor or async result. /// This method panics is handler returns actor or async result.
pub fn run<H: Handler<S>>(self, mut h: H) -> pub fn run<H: Handler<S>>(
Result<HttpResponse, <<H as Handler<S>>::Result as Responder>::Error> self, mut h: H
{ ) -> Result<HttpResponse, <<H as Handler<S>>::Result as Responder>::Error> {
let req = self.finish(); let req = self.finish();
let resp = h.handle(req.clone()); let resp = h.handle(req.clone());
match resp.respond_to(req.drop_state()) { match resp.respond_to(req.drop_state()) {
Ok(resp) => { Ok(resp) => match resp.into().into() {
match resp.into().into() { ReplyItem::Message(resp) => Ok(resp),
ReplyItem::Message(resp) => Ok(resp), ReplyItem::Future(_) => panic!("Async handler is not supported."),
ReplyItem::Future(_) => panic!("Async handler is not supported."),
}
}, },
Err(err) => Err(err), Err(err) => Err(err),
} }
@ -575,24 +613,23 @@ impl<S> TestRequest<S> {
/// ///
/// This method panics is handler returns actor. /// This method panics is handler returns actor.
pub fn run_async<H, R, F, E>(self, h: H) -> Result<HttpResponse, E> pub fn run_async<H, R, F, E>(self, h: H) -> Result<HttpResponse, E>
where H: Fn(HttpRequest<S>) -> F + 'static, where
F: Future<Item=R, Error=E> + 'static, H: Fn(HttpRequest<S>) -> F + 'static,
R: Responder<Error=E> + 'static, F: Future<Item = R, Error = E> + 'static,
E: Into<Error> + 'static R: Responder<Error = E> + 'static,
E: Into<Error> + 'static,
{ {
let req = self.finish(); let req = self.finish();
let fut = h(req.clone()); let fut = h(req.clone());
let mut core = Core::new().unwrap(); let mut core = Core::new().unwrap();
match core.run(fut) { match core.run(fut) {
Ok(r) => { Ok(r) => match r.respond_to(req.drop_state()) {
match r.respond_to(req.drop_state()) { Ok(reply) => match reply.into().into() {
Ok(reply) => match reply.into().into() { ReplyItem::Message(resp) => Ok(resp),
ReplyItem::Message(resp) => Ok(resp), _ => panic!("Nested async replies are not supported"),
_ => panic!("Nested async replies are not supported"), },
}, Err(e) => Err(e),
Err(e) => Err(e),
}
}, },
Err(err) => Err(err), Err(err) => Err(err),
} }

View File

@ -1,33 +1,37 @@
use std::rc::Rc; use futures::{Async, Future, Poll};
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use futures::{Async, Future, Poll}; use std::rc::Rc;
use error::Error; use error::Error;
use handler::{Handler, FromRequest, Reply, ReplyItem, Responder}; use handler::{FromRequest, Handler, Reply, ReplyItem, Responder};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::HttpResponse; use httpresponse::HttpResponse;
pub struct ExtractorConfig<S: 'static, T: FromRequest<S>> { pub struct ExtractorConfig<S: 'static, T: FromRequest<S>> {
cfg: Rc<UnsafeCell<T::Config>> cfg: Rc<UnsafeCell<T::Config>>,
} }
impl<S: 'static, T: FromRequest<S>> Default for ExtractorConfig<S, T> { impl<S: 'static, T: FromRequest<S>> Default for ExtractorConfig<S, T> {
fn default() -> Self { fn default() -> Self {
ExtractorConfig { cfg: Rc::new(UnsafeCell::new(T::Config::default())) } ExtractorConfig {
cfg: Rc::new(UnsafeCell::new(T::Config::default())),
}
} }
} }
impl<S: 'static, T: FromRequest<S>> Clone for ExtractorConfig<S, T> { impl<S: 'static, T: FromRequest<S>> Clone for ExtractorConfig<S, T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ExtractorConfig { cfg: Rc::clone(&self.cfg) } ExtractorConfig {
cfg: Rc::clone(&self.cfg),
}
} }
} }
impl<S: 'static, T: FromRequest<S>> AsRef<T::Config> for ExtractorConfig<S, T> { impl<S: 'static, T: FromRequest<S>> AsRef<T::Config> for ExtractorConfig<S, T> {
fn as_ref(&self) -> &T::Config { fn as_ref(&self) -> &T::Config {
unsafe{&*self.cfg.get()} unsafe { &*self.cfg.get() }
} }
} }
@ -35,18 +39,21 @@ impl<S: 'static, T: FromRequest<S>> Deref for ExtractorConfig<S, T> {
type Target = T::Config; type Target = T::Config;
fn deref(&self) -> &T::Config { fn deref(&self) -> &T::Config {
unsafe{&*self.cfg.get()} unsafe { &*self.cfg.get() }
} }
} }
impl<S: 'static, T: FromRequest<S>> DerefMut for ExtractorConfig<S, T> { impl<S: 'static, T: FromRequest<S>> DerefMut for ExtractorConfig<S, T> {
fn deref_mut(&mut self) -> &mut T::Config { fn deref_mut(&mut self) -> &mut T::Config {
unsafe{&mut *self.cfg.get()} unsafe { &mut *self.cfg.get() }
} }
} }
pub struct With<T, S, F, R> pub struct With<T, S, F, R>
where F: Fn(T) -> R, T: FromRequest<S>, S: 'static, where
F: Fn(T) -> R,
T: FromRequest<S>,
S: 'static,
{ {
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
cfg: ExtractorConfig<S, T>, cfg: ExtractorConfig<S, T>,
@ -54,23 +61,31 @@ pub struct With<T, S, F, R>
} }
impl<T, S, F, R> With<T, S, F, R> impl<T, S, F, R> With<T, S, F, R>
where F: Fn(T) -> R, T: FromRequest<S>, S: 'static, where
F: Fn(T) -> R,
T: FromRequest<S>,
S: 'static,
{ {
pub fn new(f: F, cfg: ExtractorConfig<S, T>) -> Self { pub fn new(f: F, cfg: ExtractorConfig<S, T>) -> Self {
With{cfg, hnd: Rc::new(UnsafeCell::new(f)), _s: PhantomData} With {
cfg,
hnd: Rc::new(UnsafeCell::new(f)),
_s: PhantomData,
}
} }
} }
impl<T, S, F, R> Handler<S> for With<T, S, F, R> impl<T, S, F, R> Handler<S> for With<T, S, F, R>
where F: Fn(T) -> R + 'static, where
R: Responder + 'static, F: Fn(T) -> R + 'static,
T: FromRequest<S> + 'static, R: Responder + 'static,
S: 'static T: FromRequest<S> + 'static,
S: 'static,
{ {
type Result = Reply; type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result { fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut{ let mut fut = WithHandlerFut {
req, req,
started: false, started: false,
hnd: Rc::clone(&self.hnd), hnd: Rc::clone(&self.hnd),
@ -88,31 +103,33 @@ impl<T, S, F, R> Handler<S> for With<T, S, F, R>
} }
struct WithHandlerFut<T, S, F, R> struct WithHandlerFut<T, S, F, R>
where F: Fn(T) -> R, where
R: Responder, F: Fn(T) -> R,
T: FromRequest<S> + 'static, R: Responder,
S: 'static T: FromRequest<S> + 'static,
S: 'static,
{ {
started: bool, started: bool,
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
cfg: ExtractorConfig<S, T>, cfg: ExtractorConfig<S, T>,
req: HttpRequest<S>, req: HttpRequest<S>,
fut1: Option<Box<Future<Item=T, Error=Error>>>, fut1: Option<Box<Future<Item = T, Error = Error>>>,
fut2: Option<Box<Future<Item=HttpResponse, Error=Error>>>, fut2: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
} }
impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R> impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
where F: Fn(T) -> R, where
R: Responder + 'static, F: Fn(T) -> R,
T: FromRequest<S> + 'static, R: Responder + 'static,
S: 'static T: FromRequest<S> + 'static,
S: 'static,
{ {
type Item = HttpResponse; type Item = HttpResponse;
type Error = Error; type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut2 { if let Some(ref mut fut) = self.fut2 {
return fut.poll() return fut.poll();
} }
let item = if !self.started { let item = if !self.started {
@ -122,8 +139,8 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
Ok(Async::Ready(item)) => item, Ok(Async::Ready(item)) => item,
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut)); self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady) return Ok(Async::NotReady);
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
} else { } else {
@ -133,7 +150,7 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
} }
}; };
let hnd: &mut F = unsafe{&mut *self.hnd.get()}; let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(item).respond_to(self.req.drop_state()) { let item = match (*hnd)(item).respond_to(self.req.drop_state()) {
Ok(item) => item.into(), Ok(item) => item.into(),
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
@ -150,8 +167,11 @@ impl<T, S, F, R> Future for WithHandlerFut<T, S, F, R>
} }
pub struct With2<T1, T2, S, F, R> pub struct With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R, where
T1: FromRequest<S> + 'static, T2: FromRequest<S> + 'static, S: 'static F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{ {
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
cfg1: ExtractorConfig<S, T1>, cfg1: ExtractorConfig<S, T1>,
@ -160,26 +180,36 @@ pub struct With2<T1, T2, S, F, R>
} }
impl<T1, T2, S, F, R> With2<T1, T2, S, F, R> impl<T1, T2, S, F, R> With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R, where
T1: FromRequest<S> + 'static, T2: FromRequest<S> + 'static, S: 'static F: Fn(T1, T2) -> R,
T1: FromRequest<S> + 'static,
T2: FromRequest<S> + 'static,
S: 'static,
{ {
pub fn new(f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>) -> Self { pub fn new(
With2{hnd: Rc::new(UnsafeCell::new(f)), f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>
cfg1, cfg2, _s: PhantomData} ) -> Self {
With2 {
hnd: Rc::new(UnsafeCell::new(f)),
cfg1,
cfg2,
_s: PhantomData,
}
} }
} }
impl<T1, T2, S, F, R> Handler<S> for With2<T1, T2, S, F, R> impl<T1, T2, S, F, R> Handler<S> for With2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2) -> R + 'static,
T1: FromRequest<S> + 'static, R: Responder + 'static,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
S: 'static T2: FromRequest<S> + 'static,
S: 'static,
{ {
type Result = Reply; type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result { fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut2{ let mut fut = WithHandlerFut2 {
req, req,
started: false, started: false,
hnd: Rc::clone(&self.hnd), hnd: Rc::clone(&self.hnd),
@ -199,11 +229,12 @@ impl<T1, T2, S, F, R> Handler<S> for With2<T1, T2, S, F, R>
} }
struct WithHandlerFut2<T1, T2, S, F, R> struct WithHandlerFut2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2) -> R + 'static,
T1: FromRequest<S> + 'static, R: Responder + 'static,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
S: 'static T2: FromRequest<S> + 'static,
S: 'static,
{ {
started: bool, started: bool,
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
@ -211,24 +242,25 @@ struct WithHandlerFut2<T1, T2, S, F, R>
cfg2: ExtractorConfig<S, T2>, cfg2: ExtractorConfig<S, T2>,
req: HttpRequest<S>, req: HttpRequest<S>,
item: Option<T1>, item: Option<T1>,
fut1: Option<Box<Future<Item=T1, Error=Error>>>, fut1: Option<Box<Future<Item = T1, Error = Error>>>,
fut2: Option<Box<Future<Item=T2, Error=Error>>>, fut2: Option<Box<Future<Item = T2, Error = Error>>>,
fut3: Option<Box<Future<Item=HttpResponse, Error=Error>>>, fut3: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
} }
impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R> impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
where F: Fn(T1, T2) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2) -> R + 'static,
T1: FromRequest<S> + 'static, R: Responder + 'static,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
S: 'static T2: FromRequest<S> + 'static,
S: 'static,
{ {
type Item = HttpResponse; type Item = HttpResponse;
type Error = Error; type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut3 { if let Some(ref mut fut) = self.fut3 {
return fut.poll() return fut.poll();
} }
if !self.started { if !self.started {
@ -236,32 +268,32 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
let mut fut = T1::from_request(&self.req, self.cfg1.as_ref()); let mut fut = T1::from_request(&self.req, self.cfg1.as_ref());
match fut.poll() { match fut.poll() {
Ok(Async::Ready(item1)) => { Ok(Async::Ready(item1)) => {
let mut fut = T2::from_request(&self.req,self.cfg2.as_ref()); let mut fut = T2::from_request(&self.req, self.cfg2.as_ref());
match fut.poll() { match fut.poll() {
Ok(Async::Ready(item2)) => { Ok(Async::Ready(item2)) => {
let hnd: &mut F = unsafe{&mut *self.hnd.get()}; let hnd: &mut F = unsafe { &mut *self.hnd.get() };
match (*hnd)(item1, item2) match (*hnd)(item1, item2).respond_to(self.req.drop_state())
.respond_to(self.req.drop_state())
{ {
Ok(item) => match item.into().into() { Ok(item) => match item.into().into() {
ReplyItem::Message(resp) => ReplyItem::Message(resp) => {
return Ok(Async::Ready(resp)), return Ok(Async::Ready(resp))
}
ReplyItem::Future(fut) => { ReplyItem::Future(fut) => {
self.fut3 = Some(fut); self.fut3 = Some(fut);
return self.poll() return self.poll();
} }
}, },
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
} }
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.item = Some(item1); self.item = Some(item1);
self.fut2 = Some(Box::new(fut)); self.fut2 = Some(Box::new(fut));
return Ok(Async::NotReady); return Ok(Async::NotReady);
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut)); self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady); return Ok(Async::NotReady);
@ -275,9 +307,11 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
Async::Ready(item) => { Async::Ready(item) => {
self.item = Some(item); self.item = Some(item);
self.fut1.take(); self.fut1.take();
self.fut2 = Some(Box::new( self.fut2 = Some(Box::new(T2::from_request(
T2::from_request(&self.req, self.cfg2.as_ref()))); &self.req,
}, self.cfg2.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
} }
} }
@ -287,7 +321,7 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
}; };
let hnd: &mut F = unsafe{&mut *self.hnd.get()}; let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(self.item.take().unwrap(), item) let item = match (*hnd)(self.item.take().unwrap(), item)
.respond_to(self.req.drop_state()) .respond_to(self.req.drop_state())
{ {
@ -305,11 +339,12 @@ impl<T1, T2, S, F, R> Future for WithHandlerFut2<T1, T2, S, F, R>
} }
pub struct With3<T1, T2, T3, S, F, R> pub struct With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R, where
T1: FromRequest<S> + 'static, F: Fn(T1, T2, T3) -> R,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static, T2: FromRequest<S> + 'static,
S: 'static T3: FromRequest<S> + 'static,
S: 'static,
{ {
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
cfg1: ExtractorConfig<S, T1>, cfg1: ExtractorConfig<S, T1>,
@ -318,33 +353,44 @@ pub struct With3<T1, T2, T3, S, F, R>
_s: PhantomData<S>, _s: PhantomData<S>,
} }
impl<T1, T2, T3, S, F, R> With3<T1, T2, T3, S, F, R> impl<T1, T2, T3, S, F, R> With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R, where
T1: FromRequest<S> + 'static, F: Fn(T1, T2, T3) -> R,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static, T2: FromRequest<S> + 'static,
S: 'static T3: FromRequest<S> + 'static,
S: 'static,
{ {
pub fn new(f: F, cfg1: ExtractorConfig<S, T1>, pub fn new(
cfg2: ExtractorConfig<S, T2>, cfg3: ExtractorConfig<S, T3>) -> Self f: F, cfg1: ExtractorConfig<S, T1>, cfg2: ExtractorConfig<S, T2>,
{ cfg3: ExtractorConfig<S, T3>,
With3{hnd: Rc::new(UnsafeCell::new(f)), cfg1, cfg2, cfg3, _s: PhantomData} ) -> Self {
With3 {
hnd: Rc::new(UnsafeCell::new(f)),
cfg1,
cfg2,
cfg3,
_s: PhantomData,
}
} }
} }
impl<T1, T2, T3, S, F, R> Handler<S> for With3<T1, T2, T3, S, F, R> impl<T1, T2, T3, S, F, R> Handler<S> for With3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2, T3) -> R + 'static,
T1: FromRequest<S>, R: Responder + 'static,
T2: FromRequest<S>, T1: FromRequest<S>,
T3: FromRequest<S>, T2: FromRequest<S>,
T1: 'static, T2: 'static, T3: 'static, S: 'static T3: FromRequest<S>,
T1: 'static,
T2: 'static,
T3: 'static,
S: 'static,
{ {
type Result = Reply; type Result = Reply;
fn handle(&mut self, req: HttpRequest<S>) -> Self::Result { fn handle(&mut self, req: HttpRequest<S>) -> Self::Result {
let mut fut = WithHandlerFut3{ let mut fut = WithHandlerFut3 {
req, req,
hnd: Rc::clone(&self.hnd), hnd: Rc::clone(&self.hnd),
cfg1: self.cfg1.clone(), cfg1: self.cfg1.clone(),
@ -367,12 +413,13 @@ impl<T1, T2, T3, S, F, R> Handler<S> for With3<T1, T2, T3, S, F, R>
} }
struct WithHandlerFut3<T1, T2, T3, S, F, R> struct WithHandlerFut3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2, T3) -> R + 'static,
T1: FromRequest<S> + 'static, R: Responder + 'static,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static, T2: FromRequest<S> + 'static,
S: 'static T3: FromRequest<S> + 'static,
S: 'static,
{ {
hnd: Rc<UnsafeCell<F>>, hnd: Rc<UnsafeCell<F>>,
req: HttpRequest<S>, req: HttpRequest<S>,
@ -382,26 +429,27 @@ struct WithHandlerFut3<T1, T2, T3, S, F, R>
started: bool, started: bool,
item1: Option<T1>, item1: Option<T1>,
item2: Option<T2>, item2: Option<T2>,
fut1: Option<Box<Future<Item=T1, Error=Error>>>, fut1: Option<Box<Future<Item = T1, Error = Error>>>,
fut2: Option<Box<Future<Item=T2, Error=Error>>>, fut2: Option<Box<Future<Item = T2, Error = Error>>>,
fut3: Option<Box<Future<Item=T3, Error=Error>>>, fut3: Option<Box<Future<Item = T3, Error = Error>>>,
fut4: Option<Box<Future<Item=HttpResponse, Error=Error>>>, fut4: Option<Box<Future<Item = HttpResponse, Error = Error>>>,
} }
impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R> impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
where F: Fn(T1, T2, T3) -> R + 'static, where
R: Responder + 'static, F: Fn(T1, T2, T3) -> R + 'static,
T1: FromRequest<S> + 'static, R: Responder + 'static,
T2: FromRequest<S> + 'static, T1: FromRequest<S> + 'static,
T3: FromRequest<S> + 'static, T2: FromRequest<S> + 'static,
S: 'static T3: FromRequest<S> + 'static,
S: 'static,
{ {
type Item = HttpResponse; type Item = HttpResponse;
type Error = Error; type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut4 { if let Some(ref mut fut) = self.fut4 {
return fut.poll() return fut.poll();
} }
if !self.started { if !self.started {
@ -412,41 +460,43 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
let mut fut = T2::from_request(&self.req, self.cfg2.as_ref()); let mut fut = T2::from_request(&self.req, self.cfg2.as_ref());
match fut.poll() { match fut.poll() {
Ok(Async::Ready(item2)) => { Ok(Async::Ready(item2)) => {
let mut fut = T3::from_request(&self.req, self.cfg3.as_ref()); let mut fut =
T3::from_request(&self.req, self.cfg3.as_ref());
match fut.poll() { match fut.poll() {
Ok(Async::Ready(item3)) => { Ok(Async::Ready(item3)) => {
let hnd: &mut F = unsafe{&mut *self.hnd.get()}; let hnd: &mut F = unsafe { &mut *self.hnd.get() };
match (*hnd)(item1, item2, item3) match (*hnd)(item1, item2, item3)
.respond_to(self.req.drop_state()) .respond_to(self.req.drop_state())
{ {
Ok(item) => match item.into().into() { Ok(item) => match item.into().into() {
ReplyItem::Message(resp) => ReplyItem::Message(resp) => {
return Ok(Async::Ready(resp)), return Ok(Async::Ready(resp))
}
ReplyItem::Future(fut) => { ReplyItem::Future(fut) => {
self.fut4 = Some(fut); self.fut4 = Some(fut);
return self.poll() return self.poll();
} }
}, },
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
} }
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.item1 = Some(item1); self.item1 = Some(item1);
self.item2 = Some(item2); self.item2 = Some(item2);
self.fut3 = Some(Box::new(fut)); self.fut3 = Some(Box::new(fut));
return Ok(Async::NotReady); return Ok(Async::NotReady);
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.item1 = Some(item1); self.item1 = Some(item1);
self.fut2 = Some(Box::new(fut)); self.fut2 = Some(Box::new(fut));
return Ok(Async::NotReady); return Ok(Async::NotReady);
}, }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
}, }
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.fut1 = Some(Box::new(fut)); self.fut1 = Some(Box::new(fut));
return Ok(Async::NotReady); return Ok(Async::NotReady);
@ -460,9 +510,11 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::Ready(item) => { Async::Ready(item) => {
self.item1 = Some(item); self.item1 = Some(item);
self.fut1.take(); self.fut1.take();
self.fut2 = Some(Box::new( self.fut2 = Some(Box::new(T2::from_request(
T2::from_request(&self.req, self.cfg2.as_ref()))); &self.req,
}, self.cfg2.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
} }
} }
@ -472,9 +524,11 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::Ready(item) => { Async::Ready(item) => {
self.item2 = Some(item); self.item2 = Some(item);
self.fut2.take(); self.fut2.take();
self.fut3 = Some(Box::new( self.fut3 = Some(Box::new(T3::from_request(
T3::from_request(&self.req, self.cfg3.as_ref()))); &self.req,
}, self.cfg3.as_ref(),
)));
}
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
} }
} }
@ -484,11 +538,12 @@ impl<T1, T2, T3, S, F, R> Future for WithHandlerFut3<T1, T2, T3, S, F, R>
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
}; };
let hnd: &mut F = unsafe{&mut *self.hnd.get()}; let hnd: &mut F = unsafe { &mut *self.hnd.get() };
let item = match (*hnd)(self.item1.take().unwrap(), let item = match (*hnd)(
self.item2.take().unwrap(), self.item1.take().unwrap(),
item) self.item2.take().unwrap(),
.respond_to(self.req.drop_state()) item,
).respond_to(self.req.drop_state())
{ {
Ok(item) => item.into(), Ok(item) => item.into(),
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),

View File

@ -1,67 +1,65 @@
//! Http client request //! Http client request
use std::{fmt, io, str};
use std::rc::Rc;
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::rc::Rc;
use std::time::Duration; use std::time::Duration;
use std::{fmt, io, str};
use base64; use base64;
use rand; use byteorder::{ByteOrder, NetworkEndian};
use bytes::Bytes; use bytes::Bytes;
use cookie::Cookie; use cookie::Cookie;
use byteorder::{ByteOrder, NetworkEndian};
use http::{HttpTryFrom, StatusCode, Error as HttpError};
use http::header::{self, HeaderName, HeaderValue};
use sha1::Sha1;
use futures::{Async, Future, Poll, Stream};
use futures::unsync::mpsc::{unbounded, UnboundedSender}; use futures::unsync::mpsc::{unbounded, UnboundedSender};
use futures::{Async, Future, Poll, Stream};
use http::header::{self, HeaderName, HeaderValue};
use http::{Error as HttpError, HttpTryFrom, StatusCode};
use rand;
use sha1::Sha1;
use actix::prelude::*; use actix::prelude::*;
use body::{Body, Binary}; use body::{Binary, Body};
use error::{Error, UrlParseError}; use error::{Error, UrlParseError};
use header::IntoHeaderValue; use header::IntoHeaderValue;
use payload::PayloadHelper;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use payload::PayloadHelper;
use client::{ClientRequest, ClientRequestBuilder, ClientResponse, use client::{ClientConnector, ClientRequest, ClientRequestBuilder, ClientResponse,
ClientConnector, SendRequest, SendRequestError, HttpResponseParserError, SendRequest, SendRequestError};
HttpResponseParserError};
use super::{Message, ProtocolError};
use super::frame::Frame; use super::frame::Frame;
use super::proto::{CloseCode, OpCode}; use super::proto::{CloseCode, OpCode};
use super::{Message, ProtocolError};
/// Websocket client error /// Websocket client error
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum ClientError { pub enum ClientError {
#[fail(display="Invalid url")] #[fail(display = "Invalid url")]
InvalidUrl, InvalidUrl,
#[fail(display="Invalid response status")] #[fail(display = "Invalid response status")]
InvalidResponseStatus(StatusCode), InvalidResponseStatus(StatusCode),
#[fail(display="Invalid upgrade header")] #[fail(display = "Invalid upgrade header")]
InvalidUpgradeHeader, InvalidUpgradeHeader,
#[fail(display="Invalid connection header")] #[fail(display = "Invalid connection header")]
InvalidConnectionHeader(HeaderValue), InvalidConnectionHeader(HeaderValue),
#[fail(display="Missing CONNECTION header")] #[fail(display = "Missing CONNECTION header")]
MissingConnectionHeader, MissingConnectionHeader,
#[fail(display="Missing SEC-WEBSOCKET-ACCEPT header")] #[fail(display = "Missing SEC-WEBSOCKET-ACCEPT header")]
MissingWebSocketAcceptHeader, MissingWebSocketAcceptHeader,
#[fail(display="Invalid challenge response")] #[fail(display = "Invalid challenge response")]
InvalidChallengeResponse(String, HeaderValue), InvalidChallengeResponse(String, HeaderValue),
#[fail(display="Http parsing error")] #[fail(display = "Http parsing error")]
Http(Error), Http(Error),
#[fail(display="Url parsing error")] #[fail(display = "Url parsing error")]
Url(UrlParseError), Url(UrlParseError),
#[fail(display="Response parsing error")] #[fail(display = "Response parsing error")]
ResponseParseError(HttpResponseParserError), ResponseParseError(HttpResponseParserError),
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
SendRequest(SendRequestError), SendRequest(SendRequestError),
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Protocol(#[cause] ProtocolError), Protocol(#[cause] ProtocolError),
#[fail(display="{}", _0)] #[fail(display = "{}", _0)]
Io(io::Error), Io(io::Error),
#[fail(display="Disconnected")] #[fail(display = "Disconnected")]
Disconnected, Disconnected,
} }
@ -117,14 +115,15 @@ pub struct Client {
} }
impl Client { impl Client {
/// Create new websocket connection /// Create new websocket connection
pub fn new<S: AsRef<str>>(uri: S) -> Client { pub fn new<S: AsRef<str>>(uri: S) -> Client {
Client::with_connector(uri, ClientConnector::from_registry()) Client::with_connector(uri, ClientConnector::from_registry())
} }
/// Create new websocket connection with custom `ClientConnector` /// Create new websocket connection with custom `ClientConnector`
pub fn with_connector<S: AsRef<str>>(uri: S, conn: Addr<Unsync, ClientConnector>) -> Client { pub fn with_connector<S: AsRef<str>>(
uri: S, conn: Addr<Unsync, ClientConnector>
) -> Client {
let mut cl = Client { let mut cl = Client {
request: ClientRequest::build(), request: ClientRequest::build(),
err: None, err: None,
@ -140,11 +139,13 @@ impl Client {
/// Set supported websocket protocols /// Set supported websocket protocols
pub fn protocols<U, V>(mut self, protos: U) -> Self pub fn protocols<U, V>(mut self, protos: U) -> Self
where U: IntoIterator<Item=V> + 'static, where
V: AsRef<str> U: IntoIterator<Item = V> + 'static,
V: AsRef<str>,
{ {
let mut protos = protos.into_iter() let mut protos = protos
.fold(String::new(), |acc, s| {acc + s.as_ref() + ","}); .into_iter()
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
protos.pop(); protos.pop();
self.protocols = Some(protos); self.protocols = Some(protos);
self self
@ -158,7 +159,8 @@ impl Client {
/// Set request Origin /// Set request Origin
pub fn origin<V>(mut self, origin: V) -> Self pub fn origin<V>(mut self, origin: V) -> Self
where HeaderValue: HttpTryFrom<V> where
HeaderValue: HttpTryFrom<V>,
{ {
match HeaderValue::try_from(origin) { match HeaderValue::try_from(origin) {
Ok(value) => self.origin = Some(value), Ok(value) => self.origin = Some(value),
@ -185,7 +187,9 @@ impl Client {
/// Set request header /// Set request header
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>, V: IntoHeaderValue where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{ {
self.request.header(key, value); self.request.header(key, value);
self self
@ -204,8 +208,7 @@ impl Client {
pub fn connect(&mut self) -> ClientHandshake { pub fn connect(&mut self) -> ClientHandshake {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
ClientHandshake::error(e) ClientHandshake::error(e)
} } else if let Some(e) = self.http_err.take() {
else if let Some(e) = self.http_err.take() {
ClientHandshake::error(Error::from(e).into()) ClientHandshake::error(Error::from(e).into())
} else { } else {
// origin // origin
@ -216,11 +219,13 @@ impl Client {
self.request.upgrade(); self.request.upgrade();
self.request.set_header(header::UPGRADE, "websocket"); self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade"); self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header(header::SEC_WEBSOCKET_VERSION, "13"); self.request
.set_header(header::SEC_WEBSOCKET_VERSION, "13");
self.request.with_connector(self.conn.clone()); self.request.with_connector(self.conn.clone());
if let Some(protocols) = self.protocols.take() { if let Some(protocols) = self.protocols.take() {
self.request.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str()); self.request
.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str());
} }
let request = match self.request.finish() { let request = match self.request.finish() {
Ok(req) => req, Ok(req) => req,
@ -228,14 +233,16 @@ impl Client {
}; };
if request.uri().host().is_none() { if request.uri().host().is_none() {
return ClientHandshake::error(ClientError::InvalidUrl) return ClientHandshake::error(ClientError::InvalidUrl);
} }
if let Some(scheme) = request.uri().scheme_part() { if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { if scheme != "http" && scheme != "https" && scheme != "ws"
return ClientHandshake::error(ClientError::InvalidUrl) && scheme != "wss"
{
return ClientHandshake::error(ClientError::InvalidUrl);
} }
} else { } else {
return ClientHandshake::error(ClientError::InvalidUrl) return ClientHandshake::error(ClientError::InvalidUrl);
} }
// start handshake // start handshake
@ -263,8 +270,7 @@ pub struct ClientHandshake {
} }
impl ClientHandshake { impl ClientHandshake {
fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake {
{
// Generate a random key for the `Sec-WebSocket-Key` header. // Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that, // a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455) // when decoded, is 16 bytes in length (RFC 6455)
@ -273,12 +279,13 @@ impl ClientHandshake {
request.headers_mut().insert( request.headers_mut().insert(
header::SEC_WEBSOCKET_KEY, header::SEC_WEBSOCKET_KEY,
HeaderValue::try_from(key.as_str()).unwrap()); HeaderValue::try_from(key.as_str()).unwrap(),
);
let (tx, rx) = unbounded(); let (tx, rx) = unbounded();
request.set_body(Body::Streaming( request.set_body(Body::Streaming(Box::new(rx.map_err(|_| {
Box::new(rx.map_err(|_| io::Error::new( io::Error::new(io::ErrorKind::Other, "disconnected").into()
io::ErrorKind::Other, "disconnected").into())))); }))));
ClientHandshake { ClientHandshake {
key, key,
@ -329,20 +336,20 @@ impl Future for ClientHandshake {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
return Err(err) return Err(err);
} }
let resp = match self.request.as_mut().unwrap().poll()? { let resp = match self.request.as_mut().unwrap().poll()? {
Async::Ready(response) => { Async::Ready(response) => {
self.request.take(); self.request.take();
response response
}, }
Async::NotReady => return Ok(Async::NotReady) Async::NotReady => return Ok(Async::NotReady),
}; };
// verify response // verify response
if resp.status() != StatusCode::SWITCHING_PROTOCOLS { if resp.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(ClientError::InvalidResponseStatus(resp.status())) return Err(ClientError::InvalidResponseStatus(resp.status()));
} }
// Check for "UPGRADE" to websocket header // Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) {
@ -356,26 +363,25 @@ impl Future for ClientHandshake {
}; };
if !has_hdr { if !has_hdr {
trace!("Invalid upgrade header"); trace!("Invalid upgrade header");
return Err(ClientError::InvalidUpgradeHeader) return Err(ClientError::InvalidUpgradeHeader);
} }
// Check for "CONNECTION" header // Check for "CONNECTION" header
if let Some(conn) = resp.headers().get(header::CONNECTION) { if let Some(conn) = resp.headers().get(header::CONNECTION) {
if let Ok(s) = conn.to_str() { if let Ok(s) = conn.to_str() {
if !s.to_lowercase().contains("upgrade") { if !s.to_lowercase().contains("upgrade") {
trace!("Invalid connection header: {}", s); trace!("Invalid connection header: {}", s);
return Err(ClientError::InvalidConnectionHeader(conn.clone())) return Err(ClientError::InvalidConnectionHeader(conn.clone()));
} }
} else { } else {
trace!("Invalid connection header: {:?}", conn); trace!("Invalid connection header: {:?}", conn);
return Err(ClientError::InvalidConnectionHeader(conn.clone())) return Err(ClientError::InvalidConnectionHeader(conn.clone()));
} }
} else { } else {
trace!("Missing connection header"); trace!("Missing connection header");
return Err(ClientError::MissingConnectionHeader) return Err(ClientError::MissingConnectionHeader);
} }
if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) {
{
// field is constructed by concatenating /key/ // field is constructed by concatenating /key/
// with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@ -386,12 +392,17 @@ impl Future for ClientHandshake {
if key.as_bytes() != encoded.as_bytes() { if key.as_bytes() != encoded.as_bytes() {
trace!( trace!(
"Invalid challenge response: expected: {} received: {:?}", "Invalid challenge response: expected: {} received: {:?}",
encoded, key); encoded,
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); key
);
return Err(ClientError::InvalidChallengeResponse(
encoded,
key.clone(),
));
} }
} else { } else {
trace!("Missing SEC-WEBSOCKET-ACCEPT header"); trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(ClientError::MissingWebSocketAcceptHeader) return Err(ClientError::MissingWebSocketAcceptHeader);
}; };
let inner = Inner { let inner = Inner {
@ -401,13 +412,16 @@ impl Future for ClientHandshake {
}; };
let inner = Rc::new(UnsafeCell::new(inner)); let inner = Rc::new(UnsafeCell::new(inner));
Ok(Async::Ready( Ok(Async::Ready((
(ClientReader{inner: Rc::clone(&inner), max_size: self.max_size}, ClientReader {
ClientWriter{inner}))) inner: Rc::clone(&inner),
max_size: self.max_size,
},
ClientWriter { inner },
)))
} }
} }
pub struct ClientReader { pub struct ClientReader {
inner: Rc<UnsafeCell<Inner>>, inner: Rc<UnsafeCell<Inner>>,
max_size: usize, max_size: usize,
@ -422,7 +436,7 @@ impl fmt::Debug for ClientReader {
impl ClientReader { impl ClientReader {
#[inline] #[inline]
fn as_mut(&mut self) -> &mut Inner { fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() } unsafe { &mut *self.inner.get() }
} }
} }
@ -434,7 +448,7 @@ impl Stream for ClientReader {
let max_size = self.max_size; let max_size = self.max_size;
let inner = self.as_mut(); let inner = self.as_mut();
if inner.closed { if inner.closed {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None));
} }
// read // read
@ -447,31 +461,29 @@ impl Stream for ClientReader {
OpCode::Continue => { OpCode::Continue => {
inner.closed = true; inner.closed = true;
Err(ProtocolError::NoContinuation) Err(ProtocolError::NoContinuation)
}, }
OpCode::Bad => { OpCode::Bad => {
inner.closed = true; inner.closed = true;
Err(ProtocolError::BadOpCode) Err(ProtocolError::BadOpCode)
}, }
OpCode::Close => { OpCode::Close => {
inner.closed = true; inner.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
Ok(Async::Ready(Some(Message::Close(CloseCode::from(code))))) Ok(Async::Ready(Some(Message::Close(CloseCode::from(
}, code,
OpCode::Ping => )))))
Ok(Async::Ready(Some( }
Message::Ping( OpCode::Ping => Ok(Async::Ready(Some(Message::Ping(
String::from_utf8_lossy(payload.as_ref()).into())))), String::from_utf8_lossy(payload.as_ref()).into(),
OpCode::Pong => )))),
Ok(Async::Ready(Some( OpCode::Pong => Ok(Async::Ready(Some(Message::Pong(
Message::Pong( String::from_utf8_lossy(payload.as_ref()).into(),
String::from_utf8_lossy(payload.as_ref()).into())))), )))),
OpCode::Binary => OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))),
Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => { OpCode::Text => {
let tmp = Vec::from(payload.as_ref()); let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) { match String::from_utf8(tmp) {
Ok(s) => Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))),
Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => { Err(_) => {
inner.closed = true; inner.closed = true;
Err(ProtocolError::BadEncoding) Err(ProtocolError::BadEncoding)
@ -491,18 +503,17 @@ impl Stream for ClientReader {
} }
pub struct ClientWriter { pub struct ClientWriter {
inner: Rc<UnsafeCell<Inner>> inner: Rc<UnsafeCell<Inner>>,
} }
impl ClientWriter { impl ClientWriter {
#[inline] #[inline]
fn as_mut(&mut self) -> &mut Inner { fn as_mut(&mut self) -> &mut Inner {
unsafe{ &mut *self.inner.get() } unsafe { &mut *self.inner.get() }
} }
} }
impl ClientWriter { impl ClientWriter {
/// Write payload /// Write payload
#[inline] #[inline]
fn write(&mut self, mut data: Binary) { fn write(&mut self, mut data: Binary) {
@ -528,13 +539,23 @@ impl ClientWriter {
/// Send ping frame /// Send ping frame
#[inline] #[inline]
pub fn ping(&mut self, message: &str) { pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true, true)); self.write(Frame::message(
Vec::from(message),
OpCode::Ping,
true,
true,
));
} }
/// Send pong frame /// Send pong frame
#[inline] #[inline]
pub fn pong(&mut self, message: &str) { pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true, true)); self.write(Frame::message(
Vec::from(message),
OpCode::Pong,
true,
true,
));
} }
/// Send close frame /// Send close frame

View File

@ -1,25 +1,26 @@
use std::mem;
use futures::{Async, Poll};
use futures::sync::oneshot::Sender; use futures::sync::oneshot::Sender;
use futures::unsync::oneshot; use futures::unsync::oneshot;
use futures::{Async, Poll};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::mem;
use actix::{Actor, ActorState, ActorContext, AsyncContext, use actix::dev::{ContextImpl, SyncEnvelope, ToEnvelope};
Addr, Handler, Message, Syn, Unsync, SpawnHandle};
use actix::fut::ActorFuture; use actix::fut::ActorFuture;
use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope}; use actix::{Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message,
SpawnHandle, Syn, Unsync};
use body::{Body, Binary}; use body::{Binary, Body};
use context::{ActorHttpContext, Drain, Frame as ContextFrame};
use error::{Error, ErrorInternalServerError}; use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use context::{Frame as ContextFrame, ActorHttpContext, Drain};
use ws::frame::Frame; use ws::frame::Frame;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{CloseCode, OpCode};
/// Execution context for `WebSockets` actors /// Execution context for `WebSockets` actors
pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A, S>>, pub struct WebsocketContext<A, S = ()>
where
A: Actor<Context = WebsocketContext<A, S>>,
{ {
inner: ContextImpl<A>, inner: ContextImpl<A>,
stream: Option<SmallVec<[ContextFrame; 4]>>, stream: Option<SmallVec<[ContextFrame; 4]>>,
@ -27,7 +28,9 @@ pub struct WebsocketContext<A, S=()> where A: Actor<Context=WebsocketContext<A,
disconnected: bool, disconnected: bool,
} }
impl<A, S> ActorContext for WebsocketContext<A, S> where A: Actor<Context=Self> impl<A, S> ActorContext for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{ {
fn stop(&mut self) { fn stop(&mut self) {
self.inner.stop(); self.inner.stop();
@ -40,16 +43,20 @@ impl<A, S> ActorContext for WebsocketContext<A, S> where A: Actor<Context=Self>
} }
} }
impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Self> impl<A, S> AsyncContext<A> for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{ {
fn spawn<F>(&mut self, fut: F) -> SpawnHandle fn spawn<F>(&mut self, fut: F) -> SpawnHandle
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{ {
self.inner.spawn(fut) self.inner.spawn(fut)
} }
fn wait<F>(&mut self, fut: F) fn wait<F>(&mut self, fut: F)
where F: ActorFuture<Item=(), Error=(), Actor=A> + 'static where
F: ActorFuture<Item = (), Error = (), Actor = A> + 'static,
{ {
self.inner.wait(fut) self.inner.wait(fut)
} }
@ -57,8 +64,8 @@ impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Sel
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
fn waiting(&self) -> bool { fn waiting(&self) -> bool {
self.inner.waiting() || self.inner.state() == ActorState::Stopping || self.inner.waiting() || self.inner.state() == ActorState::Stopping
self.inner.state() == ActorState::Stopped || self.inner.state() == ActorState::Stopped
} }
fn cancel_future(&mut self, handle: SpawnHandle) -> bool { fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
@ -78,8 +85,10 @@ impl<A, S> AsyncContext<A> for WebsocketContext<A, S> where A: Actor<Context=Sel
} }
} }
impl<A, S: 'static> WebsocketContext<A, S> where A: Actor<Context=Self> { impl<A, S: 'static> WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
#[inline] #[inline]
pub fn new(req: HttpRequest<S>, actor: A) -> WebsocketContext<A, S> { pub fn new(req: HttpRequest<S>, actor: A) -> WebsocketContext<A, S> {
WebsocketContext::from_request(req).actor(actor) WebsocketContext::from_request(req).actor(actor)
@ -101,8 +110,10 @@ impl<A, S: 'static> WebsocketContext<A, S> where A: Actor<Context=Self> {
} }
} }
impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> { impl<A, S> WebsocketContext<A, S>
where
A: Actor<Context = Self>,
{
/// Write payload /// Write payload
#[inline] #[inline]
fn write(&mut self, data: Binary) { fn write(&mut self, data: Binary) {
@ -145,13 +156,23 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Send ping frame /// Send ping frame
#[inline] #[inline]
pub fn ping(&mut self, message: &str) { pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false)); self.write(Frame::message(
Vec::from(message),
OpCode::Ping,
true,
false,
));
} }
/// Send pong frame /// Send pong frame
#[inline] #[inline]
pub fn pong(&mut self, message: &str) { pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false)); self.write(Frame::message(
Vec::from(message),
OpCode::Pong,
true,
false,
));
} }
/// Send close frame /// Send close frame
@ -191,8 +212,11 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
} }
} }
impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Self>, S: 'static { impl<A, S> ActorHttpContext for WebsocketContext<A, S>
where
A: Actor<Context = Self>,
S: 'static,
{
#[inline] #[inline]
fn disconnected(&mut self) { fn disconnected(&mut self) {
self.disconnected = true; self.disconnected = true;
@ -200,12 +224,11 @@ impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Se
} }
fn poll(&mut self) -> Poll<Option<SmallVec<[ContextFrame; 4]>>, Error> { fn poll(&mut self) -> Poll<Option<SmallVec<[ContextFrame; 4]>>, Error> {
let ctx: &mut WebsocketContext<A, S> = unsafe { let ctx: &mut WebsocketContext<A, S> =
mem::transmute(self as &mut WebsocketContext<A, S>) unsafe { mem::transmute(self as &mut WebsocketContext<A, S>) };
};
if self.inner.alive() && self.inner.poll(ctx).is_err() { if self.inner.alive() && self.inner.poll(ctx).is_err() {
return Err(ErrorInternalServerError("error")) return Err(ErrorInternalServerError("error"));
} }
// frames // frames
@ -220,8 +243,10 @@ impl<A, S> ActorHttpContext for WebsocketContext<A, S> where A: Actor<Context=Se
} }
impl<A, M, S> ToEnvelope<Syn, A, M> for WebsocketContext<A, S> impl<A, M, S> ToEnvelope<Syn, A, M> for WebsocketContext<A, S>
where A: Actor<Context=WebsocketContext<A, S>> + Handler<M>, where
M: Message + Send + 'static, M::Result: Send A: Actor<Context = WebsocketContext<A, S>> + Handler<M>,
M: Message + Send + 'static,
M::Result: Send,
{ {
fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> { fn pack(msg: M, tx: Option<Sender<M::Result>>) -> SyncEnvelope<A> {
SyncEnvelope::new(msg, tx) SyncEnvelope::new(msg, tx)
@ -229,7 +254,9 @@ impl<A, M, S> ToEnvelope<Syn, A, M> for WebsocketContext<A, S>
} }
impl<A, S> From<WebsocketContext<A, S>> for Body impl<A, S> From<WebsocketContext<A, S>> for Body
where A: Actor<Context=WebsocketContext<A, S>>, S: 'static where
A: Actor<Context = WebsocketContext<A, S>>,
S: 'static,
{ {
fn from(ctx: WebsocketContext<A, S>) -> Body { fn from(ctx: WebsocketContext<A, S>) -> Body {
Body::Actor(Box::new(ctx)) Body::Actor(Box::new(ctx))

View File

@ -1,17 +1,17 @@
use std::{fmt, mem, ptr}; use byteorder::{BigEndian, ByteOrder, NetworkEndian};
use std::iter::FromIterator; use bytes::{BufMut, Bytes, BytesMut};
use bytes::{Bytes, BytesMut, BufMut};
use byteorder::{ByteOrder, BigEndian, NetworkEndian};
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use rand; use rand;
use std::iter::FromIterator;
use std::{fmt, mem, ptr};
use body::Binary; use body::Binary;
use error::{PayloadError}; use error::PayloadError;
use payload::PayloadHelper; use payload::PayloadHelper;
use ws::ProtocolError; use ws::ProtocolError;
use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask; use ws::mask::apply_mask;
use ws::proto::{CloseCode, OpCode};
/// A struct representing a `WebSocket` frame. /// A struct representing a `WebSocket` frame.
#[derive(Debug)] #[derive(Debug)]
@ -22,7 +22,6 @@ pub struct Frame {
} }
impl Frame { impl Frame {
/// Destruct frame /// Destruct frame
pub fn unpack(self) -> (bool, OpCode, Binary) { pub fn unpack(self) -> (bool, OpCode, Binary) {
(self.finished, self.opcode, self.payload) (self.finished, self.opcode, self.payload)
@ -40,20 +39,22 @@ impl Frame {
Vec::new() Vec::new()
} else { } else {
Vec::from_iter( Vec::from_iter(
raw[..].iter() raw[..]
.iter()
.chain(reason.as_bytes().iter()) .chain(reason.as_bytes().iter())
.cloned()) .cloned(),
)
}; };
Frame::message(payload, OpCode::Close, true, genmask) Frame::message(payload, OpCode::Close, true, genmask)
} }
#[cfg_attr(feature="cargo-clippy", allow(type_complexity))] #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))]
fn read_copy_md<S>(pl: &mut PayloadHelper<S>, fn read_copy_md<S>(
server: bool, pl: &mut PayloadHelper<S>, server: bool, max_size: usize
max_size: usize
) -> Poll<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> ) -> Poll<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError>
where S: Stream<Item=Bytes, Error=PayloadError> where
S: Stream<Item = Bytes, Error = PayloadError>,
{ {
let mut idx = 2; let mut idx = 2;
let buf = match pl.copy(2)? { let buf = match pl.copy(2)? {
@ -68,16 +69,16 @@ impl Frame {
// check masking // check masking
let masked = second & 0x80 != 0; let masked = second & 0x80 != 0;
if !masked && server { if !masked && server {
return Err(ProtocolError::UnmaskedFrame) return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server { } else if masked && !server {
return Err(ProtocolError::MaskedFrame) return Err(ProtocolError::MaskedFrame);
} }
// Op code // Op code
let opcode = OpCode::from(first & 0x0F); let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode { if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F)) return Err(ProtocolError::InvalidOpcode(first & 0x0F));
} }
let len = second & 0x7F; let len = second & 0x7F;
@ -105,7 +106,7 @@ impl Frame {
// check for max allowed size // check for max allowed size
if length > max_size { if length > max_size {
return Err(ProtocolError::Overflow) return Err(ProtocolError::Overflow);
} }
let mask = if server { let mask = if server {
@ -115,25 +116,32 @@ impl Frame {
Async::NotReady => return Ok(Async::NotReady), Async::NotReady => return Ok(Async::NotReady),
}; };
let mask: &[u8] = &buf[idx..idx+4]; let mask: &[u8] = &buf[idx..idx + 4];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; let mask_u32: u32 =
unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
idx += 4; idx += 4;
Some(mask_u32) Some(mask_u32)
} else { } else {
None None
}; };
Ok(Async::Ready(Some((idx, finished, opcode, length, mask)))) Ok(Async::Ready(Some((
idx,
finished,
opcode,
length,
mask,
))))
} }
fn read_chunk_md(chunk: &[u8], server: bool, max_size: usize) fn read_chunk_md(
-> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError> chunk: &[u8], server: bool, max_size: usize
{ ) -> Poll<(usize, bool, OpCode, usize, Option<u32>), ProtocolError> {
let chunk_len = chunk.len(); let chunk_len = chunk.len();
let mut idx = 2; let mut idx = 2;
if chunk_len < 2 { if chunk_len < 2 {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
let first = chunk[0]; let first = chunk[0];
@ -143,29 +151,29 @@ impl Frame {
// check masking // check masking
let masked = second & 0x80 != 0; let masked = second & 0x80 != 0;
if !masked && server { if !masked && server {
return Err(ProtocolError::UnmaskedFrame) return Err(ProtocolError::UnmaskedFrame);
} else if masked && !server { } else if masked && !server {
return Err(ProtocolError::MaskedFrame) return Err(ProtocolError::MaskedFrame);
} }
// Op code // Op code
let opcode = OpCode::from(first & 0x0F); let opcode = OpCode::from(first & 0x0F);
if let OpCode::Bad = opcode { if let OpCode::Bad = opcode {
return Err(ProtocolError::InvalidOpcode(first & 0x0F)) return Err(ProtocolError::InvalidOpcode(first & 0x0F));
} }
let len = second & 0x7F; let len = second & 0x7F;
let length = if len == 126 { let length = if len == 126 {
if chunk_len < 4 { if chunk_len < 4 {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize; let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize;
idx += 2; idx += 2;
len len
} else if len == 127 { } else if len == 127 {
if chunk_len < 10 { if chunk_len < 10 {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
let len = NetworkEndian::read_uint(&chunk[idx..], 8) as usize; let len = NetworkEndian::read_uint(&chunk[idx..], 8) as usize;
idx += 8; idx += 8;
@ -176,16 +184,17 @@ impl Frame {
// check for max allowed size // check for max allowed size
if length > max_size { if length > max_size {
return Err(ProtocolError::Overflow) return Err(ProtocolError::Overflow);
} }
let mask = if server { let mask = if server {
if chunk_len < idx + 4 { if chunk_len < idx + 4 {
return Ok(Async::NotReady) return Ok(Async::NotReady);
} }
let mask: &[u8] = &chunk[idx..idx+4]; let mask: &[u8] = &chunk[idx..idx + 4];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; let mask_u32: u32 =
unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
idx += 4; idx += 4;
Some(mask_u32) Some(mask_u32)
} else { } else {
@ -196,9 +205,11 @@ impl Frame {
} }
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
pub fn parse<S>(pl: &mut PayloadHelper<S>, server: bool, max_size: usize) pub fn parse<S>(
-> Poll<Option<Frame>, ProtocolError> pl: &mut PayloadHelper<S>, server: bool, max_size: usize
where S: Stream<Item=Bytes, Error=PayloadError> ) -> Poll<Option<Frame>, ProtocolError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{ {
// try to parse ws frame md from one chunk // try to parse ws frame md from one chunk
let result = match pl.get_chunk()? { let result = match pl.get_chunk()? {
@ -229,7 +240,10 @@ impl Frame {
// no need for body // no need for body
if length == 0 { if length == 0 {
return Ok(Async::Ready(Some(Frame { return Ok(Async::Ready(Some(Frame {
finished, opcode, payload: Binary::from("") }))); finished,
opcode,
payload: Binary::from(""),
})));
} }
let data = match pl.read_exact(length)? { let data = match pl.read_exact(length)? {
@ -245,26 +259,32 @@ impl Frame {
} }
OpCode::Close if length > 125 => { OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Async::Ready(Some(Frame::default()))) return Ok(Async::Ready(Some(Frame::default())));
} }
_ => () _ => (),
} }
// unmask // unmask
if let Some(mask) = mask { if let Some(mask) = mask {
#[allow(mutable_transmutes)] #[allow(mutable_transmutes)]
let p: &mut [u8] = unsafe{let ptr: &[u8] = &data; mem::transmute(ptr)}; let p: &mut [u8] = unsafe {
let ptr: &[u8] = &data;
mem::transmute(ptr)
};
apply_mask(p, mask); apply_mask(p, mask);
} }
Ok(Async::Ready(Some(Frame { Ok(Async::Ready(Some(Frame {
finished, opcode, payload: data.into() }))) finished,
opcode,
payload: data.into(),
})))
} }
/// Generate binary representation /// Generate binary representation
pub fn message<B: Into<Binary>>(data: B, code: OpCode, pub fn message<B: Into<Binary>>(
finished: bool, genmask: bool) -> Binary data: B, code: OpCode, finished: bool, genmask: bool
{ ) -> Binary {
let payload = data.into(); let payload = data.into();
let one: u8 = if finished { let one: u8 = if finished {
0x80 | Into::<u8>::into(code) 0x80 | Into::<u8>::into(code)
@ -286,19 +306,19 @@ impl Frame {
let mut buf = BytesMut::with_capacity(p_len + 4); let mut buf = BytesMut::with_capacity(p_len + 4);
buf.put_slice(&[one, two | 126]); buf.put_slice(&[one, two | 126]);
{ {
let buf_mut = unsafe{buf.bytes_mut()}; let buf_mut = unsafe { buf.bytes_mut() };
BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16); BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16);
} }
unsafe{buf.advance_mut(2)}; unsafe { buf.advance_mut(2) };
buf buf
} else { } else {
let mut buf = BytesMut::with_capacity(p_len + 10); let mut buf = BytesMut::with_capacity(p_len + 10);
buf.put_slice(&[one, two | 127]); buf.put_slice(&[one, two | 127]);
{ {
let buf_mut = unsafe{buf.bytes_mut()}; let buf_mut = unsafe { buf.bytes_mut() };
BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64); BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64);
} }
unsafe{buf.advance_mut(8)}; unsafe { buf.advance_mut(8) };
buf buf
}; };
@ -308,7 +328,7 @@ impl Frame {
{ {
let buf_mut = buf.bytes_mut(); let buf_mut = buf.bytes_mut();
*(buf_mut as *mut _ as *mut u32) = mask; *(buf_mut as *mut _ as *mut u32) = mask;
buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref()); buf_mut[4..payload_len + 4].copy_from_slice(payload.as_ref());
apply_mask(&mut buf_mut[4..], mask); apply_mask(&mut buf_mut[4..], mask);
} }
buf.advance_mut(payload_len + 4); buf.advance_mut(payload_len + 4);
@ -333,7 +353,8 @@ impl Default for Frame {
impl fmt::Display for Frame { impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, write!(
f,
" "
<FRAME> <FRAME>
final: {} final: {}
@ -341,11 +362,15 @@ impl fmt::Display for Frame {
payload length: {} payload length: {}
payload: 0x{} payload: 0x{}
</FRAME>", </FRAME>",
self.finished, self.finished,
self.opcode, self.opcode,
self.payload.len(), self.payload.len(),
self.payload.as_ref().iter().map( self.payload
|byte| format!("{:x}", byte)).collect::<String>()) .as_ref()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
)
} }
} }
@ -360,7 +385,7 @@ mod tests {
_ => false, _ => false,
} }
} }
fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame { fn extract(frm: Poll<Option<Frame>, ProtocolError>) -> Frame {
match frm { match frm {
Ok(Async::Ready(Some(frame))) => frame, Ok(Async::Ready(Some(frame))) => frame,
@ -370,8 +395,9 @@ mod tests {
#[test] #[test]
fn test_parse() { fn test_parse() {
let mut buf = PayloadHelper::new( let mut buf = PayloadHelper::new(once(Ok(BytesMut::from(
once(Ok(BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]).freeze()))); &[0b0000_0001u8, 0b0000_0001u8][..],
).freeze())));
assert!(is_none(&Frame::parse(&mut buf, false, 1024))); assert!(is_none(&Frame::parse(&mut buf, false, 1024)));
let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);

View File

@ -20,7 +20,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
/// Faster version of `apply_mask()` which operates on 8-byte blocks. /// Faster version of `apply_mask()` which operates on 8-byte blocks.
#[inline] #[inline]
#[cfg_attr(feature="cargo-clippy", allow(cast_lossless))] #[cfg_attr(feature = "cargo-clippy", allow(cast_lossless))]
fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) { fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) {
let mut ptr = buf.as_mut_ptr(); let mut ptr = buf.as_mut_ptr();
let mut len = buf.len(); let mut len = buf.len();
@ -85,13 +85,16 @@ fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) {
// Possible last block. // Possible last block.
if len > 0 { if len > 0 {
unsafe { xor_mem(ptr, mask_u32, len); } unsafe {
xor_mem(ptr, mask_u32, len);
}
} }
} }
#[inline] #[inline]
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient, // TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so
// it could be done better. The compiler does not see that len is limited to 3. // inefficient, it could be done better. The compiler does not see that len is
// limited to 3.
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) { unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
let mut b: u32 = uninitialized(); let mut b: u32 = uninitialized();
#[allow(trivial_casts)] #[allow(trivial_casts)]
@ -103,19 +106,17 @@ unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{apply_mask_fallback, apply_mask_fast32};
use std::ptr; use std::ptr;
use super::{apply_mask_fallback, apply_mask_fast32};
#[test] #[test]
fn test_apply_mask() { fn test_apply_mask() {
let mask = [ let mask = [0x6d, 0xb6, 0xb2, 0x80];
0x6d, 0xb6, 0xb2, 0x80, let mask_u32: u32 = unsafe { ptr::read_unaligned(mask.as_ptr() as *const u32) };
];
let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)};
let unmasked = vec![ let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17,
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, 0x74, 0xf9, 0x12, 0x03,
]; ];
// Check masking with proper alignment. // Check masking with proper alignment.

View File

@ -1,7 +1,8 @@
//! `WebSocket` support for Actix //! `WebSocket` support for Actix
//! //!
//! To setup a `WebSocket`, first do web socket handshake then on success convert `Payload` //! To setup a `WebSocket`, first do web socket handshake then on success
//! into a `WsStream` stream and then use `WsWriter` to communicate with the peer. //! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
//! communicate with the peer.
//! //!
//! ## Example //! ## Example
//! //!
@ -42,62 +43,61 @@
//! # .finish(); //! # .finish();
//! # } //! # }
//! ``` //! ```
use bytes::Bytes;
use http::{Method, StatusCode, header};
use futures::{Async, Poll, Stream};
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, NetworkEndian};
use bytes::Bytes;
use futures::{Async, Poll, Stream};
use http::{header, Method, StatusCode};
use actix::{Actor, AsyncContext, StreamHandler}; use actix::{Actor, AsyncContext, StreamHandler};
use body::Binary; use body::Binary;
use payload::PayloadHelper;
use error::{Error, PayloadError, ResponseError}; use error::{Error, PayloadError, ResponseError};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use httprequest::HttpRequest; use httprequest::HttpRequest;
use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder};
use payload::PayloadHelper;
mod frame;
mod proto;
mod context;
mod mask;
mod client; mod client;
mod context;
mod frame;
mod mask;
mod proto;
pub use self::frame::Frame; pub use self::client::{Client, ClientError, ClientHandshake, ClientReader, ClientWriter};
pub use self::proto::OpCode;
pub use self::proto::CloseCode;
pub use self::context::WebsocketContext; pub use self::context::WebsocketContext;
pub use self::client::{Client, ClientError, pub use self::frame::Frame;
ClientReader, ClientWriter, ClientHandshake}; pub use self::proto::CloseCode;
pub use self::proto::OpCode;
/// Websocket protocol errors /// Websocket protocol errors
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
pub enum ProtocolError { pub enum ProtocolError {
/// Received an unmasked frame from client /// Received an unmasked frame from client
#[fail(display="Received an unmasked frame from client")] #[fail(display = "Received an unmasked frame from client")]
UnmaskedFrame, UnmaskedFrame,
/// Received a masked frame from server /// Received a masked frame from server
#[fail(display="Received a masked frame from server")] #[fail(display = "Received a masked frame from server")]
MaskedFrame, MaskedFrame,
/// Encountered invalid opcode /// Encountered invalid opcode
#[fail(display="Invalid opcode: {}", _0)] #[fail(display = "Invalid opcode: {}", _0)]
InvalidOpcode(u8), InvalidOpcode(u8),
/// Invalid control frame length /// Invalid control frame length
#[fail(display="Invalid control frame length: {}", _0)] #[fail(display = "Invalid control frame length: {}", _0)]
InvalidLength(usize), InvalidLength(usize),
/// Bad web socket op code /// Bad web socket op code
#[fail(display="Bad web socket op code")] #[fail(display = "Bad web socket op code")]
BadOpCode, BadOpCode,
/// A payload reached size limit. /// A payload reached size limit.
#[fail(display="A payload reached size limit.")] #[fail(display = "A payload reached size limit.")]
Overflow, Overflow,
/// Continuation is not supported /// Continuation is not supported
#[fail(display="Continuation is not supported.")] #[fail(display = "Continuation is not supported.")]
NoContinuation, NoContinuation,
/// Bad utf-8 encoding /// Bad utf-8 encoding
#[fail(display="Bad utf-8 encoding.")] #[fail(display = "Bad utf-8 encoding.")]
BadEncoding, BadEncoding,
/// Payload error /// Payload error
#[fail(display="Payload error: {}", _0)] #[fail(display = "Payload error: {}", _0)]
Payload(#[cause] PayloadError), Payload(#[cause] PayloadError),
} }
@ -113,42 +113,46 @@ impl From<PayloadError> for ProtocolError {
#[derive(Fail, PartialEq, Debug)] #[derive(Fail, PartialEq, Debug)]
pub enum HandshakeError { pub enum HandshakeError {
/// Only get method is allowed /// Only get method is allowed
#[fail(display="Method not allowed")] #[fail(display = "Method not allowed")]
GetMethodRequired, GetMethodRequired,
/// Upgrade header if not set to websocket /// Upgrade header if not set to websocket
#[fail(display="Websocket upgrade is expected")] #[fail(display = "Websocket upgrade is expected")]
NoWebsocketUpgrade, NoWebsocketUpgrade,
/// Connection header is not set to upgrade /// Connection header is not set to upgrade
#[fail(display="Connection upgrade is expected")] #[fail(display = "Connection upgrade is expected")]
NoConnectionUpgrade, NoConnectionUpgrade,
/// Websocket version header is not set /// Websocket version header is not set
#[fail(display="Websocket version header is required")] #[fail(display = "Websocket version header is required")]
NoVersionHeader, NoVersionHeader,
/// Unsupported websocket version /// Unsupported websocket version
#[fail(display="Unsupported version")] #[fail(display = "Unsupported version")]
UnsupportedVersion, UnsupportedVersion,
/// Websocket key is not set or wrong /// Websocket key is not set or wrong
#[fail(display="Unknown websocket key")] #[fail(display = "Unknown websocket key")]
BadWebsocketKey, BadWebsocketKey,
} }
impl ResponseError for HandshakeError { impl ResponseError for HandshakeError {
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
match *self { match *self {
HandshakeError::GetMethodRequired => { HandshakeError::GetMethodRequired => HttpResponse::MethodNotAllowed()
HttpResponse::MethodNotAllowed().header(header::ALLOW, "GET").finish() .header(header::ALLOW, "GET")
} .finish(),
HandshakeError::NoWebsocketUpgrade => HttpResponse::BadRequest() HandshakeError::NoWebsocketUpgrade => HttpResponse::BadRequest()
.reason("No WebSocket UPGRADE header found").finish(), .reason("No WebSocket UPGRADE header found")
.finish(),
HandshakeError::NoConnectionUpgrade => HttpResponse::BadRequest() HandshakeError::NoConnectionUpgrade => HttpResponse::BadRequest()
.reason("No CONNECTION upgrade").finish(), .reason("No CONNECTION upgrade")
.finish(),
HandshakeError::NoVersionHeader => HttpResponse::BadRequest() HandshakeError::NoVersionHeader => HttpResponse::BadRequest()
.reason("Websocket version header is required").finish(), .reason("Websocket version header is required")
.finish(),
HandshakeError::UnsupportedVersion => HttpResponse::BadRequest() HandshakeError::UnsupportedVersion => HttpResponse::BadRequest()
.reason("Unsupported version").finish(), .reason("Unsupported version")
.finish(),
HandshakeError::BadWebsocketKey => HttpResponse::BadRequest() HandshakeError::BadWebsocketKey => HttpResponse::BadRequest()
.reason("Handshake error").finish(), .reason("Handshake error")
.finish(),
} }
} }
} }
@ -165,8 +169,9 @@ pub enum Message {
/// Do websocket handshake and start actor /// Do websocket handshake and start actor
pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error> pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
where A: Actor<Context=WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>, where
S: 'static A: Actor<Context = WebsocketContext<A, S>> + StreamHandler<Message, ProtocolError>,
S: 'static,
{ {
let mut resp = handshake(&req)?; let mut resp = handshake(&req)?;
let stream = WsStream::new(req.clone()); let stream = WsStream::new(req.clone());
@ -185,10 +190,12 @@ pub fn start<A, S>(req: HttpRequest<S>, actor: A) -> Result<HttpResponse, Error>
// /// `protocols` is a sequence of known protocols. On successful handshake, // /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list // /// the returned response headers contain the first protocol in this list
// /// which the server also knows. // /// which the server also knows.
pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, HandshakeError> { pub fn handshake<S>(
req: &HttpRequest<S>
) -> Result<HttpResponseBuilder, HandshakeError> {
// WebSocket accepts only GET // WebSocket accepts only GET
if *req.method() != Method::GET { if *req.method() != Method::GET {
return Err(HandshakeError::GetMethodRequired) return Err(HandshakeError::GetMethodRequired);
} }
// Check for "UPGRADE" to websocket header // Check for "UPGRADE" to websocket header
@ -202,17 +209,19 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
false false
}; };
if !has_hdr { if !has_hdr {
return Err(HandshakeError::NoWebsocketUpgrade) return Err(HandshakeError::NoWebsocketUpgrade);
} }
// Upgrade connection // Upgrade connection
if !req.upgrade() { if !req.upgrade() {
return Err(HandshakeError::NoConnectionUpgrade) return Err(HandshakeError::NoConnectionUpgrade);
} }
// check supported version // check supported version
if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { if !req.headers()
return Err(HandshakeError::NoVersionHeader) .contains_key(header::SEC_WEBSOCKET_VERSION)
{
return Err(HandshakeError::NoVersionHeader);
} }
let supported_ver = { let supported_ver = {
if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
@ -222,12 +231,12 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
} }
}; };
if !supported_ver { if !supported_ver {
return Err(HandshakeError::UnsupportedVersion) return Err(HandshakeError::UnsupportedVersion);
} }
// check client handshake for validity // check client handshake for validity
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(HandshakeError::BadWebsocketKey) return Err(HandshakeError::BadWebsocketKey);
} }
let key = { let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
@ -235,11 +244,11 @@ pub fn handshake<S>(req: &HttpRequest<S>) -> Result<HttpResponseBuilder, Handsha
}; };
Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade) .connection_type(ConnectionType::Upgrade)
.header(header::UPGRADE, "websocket") .header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked") .header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take()) .take())
} }
/// Maps `Payload` stream into stream of `ws::Message` items /// Maps `Payload` stream into stream of `ws::Message` items
@ -249,12 +258,16 @@ pub struct WsStream<S> {
max_size: usize, max_size: usize,
} }
impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create new websocket frames stream /// Create new websocket frames stream
pub fn new(stream: S) -> WsStream<S> { pub fn new(stream: S) -> WsStream<S> {
WsStream { rx: PayloadHelper::new(stream), WsStream {
closed: false, rx: PayloadHelper::new(stream),
max_size: 65_536, closed: false,
max_size: 65_536,
} }
} }
@ -267,13 +280,16 @@ impl<S> WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
} }
} }
impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> { impl<S> Stream for WsStream<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = Message; type Item = Message;
type Error = ProtocolError; type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.closed { if self.closed {
return Ok(Async::Ready(None)) return Ok(Async::Ready(None));
} }
match Frame::parse(&mut self.rx, true, self.max_size) { match Frame::parse(&mut self.rx, true, self.max_size) {
@ -283,7 +299,7 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
// continuation is not supported // continuation is not supported
if !finished { if !finished {
self.closed = true; self.closed = true;
return Err(ProtocolError::NoContinuation) return Err(ProtocolError::NoContinuation);
} }
match opcode { match opcode {
@ -295,23 +311,21 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
OpCode::Close => { OpCode::Close => {
self.closed = true; self.closed = true;
let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16;
Ok(Async::Ready( Ok(Async::Ready(Some(Message::Close(CloseCode::from(
Some(Message::Close(CloseCode::from(code))))) code,
}, )))))
OpCode::Ping => }
Ok(Async::Ready(Some( OpCode::Ping => Ok(Async::Ready(Some(Message::Ping(
Message::Ping( String::from_utf8_lossy(payload.as_ref()).into(),
String::from_utf8_lossy(payload.as_ref()).into())))), )))),
OpCode::Pong => OpCode::Pong => Ok(Async::Ready(Some(Message::Pong(
Ok(Async::Ready(Some( String::from_utf8_lossy(payload.as_ref()).into(),
Message::Pong(String::from_utf8_lossy(payload.as_ref()).into())))), )))),
OpCode::Binary => OpCode::Binary => Ok(Async::Ready(Some(Message::Binary(payload)))),
Ok(Async::Ready(Some(Message::Binary(payload)))),
OpCode::Text => { OpCode::Text => {
let tmp = Vec::from(payload.as_ref()); let tmp = Vec::from(payload.as_ref());
match String::from_utf8(tmp) { match String::from_utf8(tmp) {
Ok(s) => Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))),
Ok(Async::Ready(Some(Message::Text(s)))),
Err(_) => { Err(_) => {
self.closed = true; self.closed = true;
Err(ProtocolError::BadEncoding) Err(ProtocolError::BadEncoding)
@ -333,77 +347,168 @@ impl<S> Stream for WsStream<S> where S: Stream<Item=Bytes, Error=PayloadError> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use http::{header, HeaderMap, Method, Uri, Version};
use std::str::FromStr; use std::str::FromStr;
use http::{Method, HeaderMap, Version, Uri, header};
#[test] #[test]
fn test_handshake() { fn test_handshake() {
let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(), let req = HttpRequest::new(
Version::HTTP_11, HeaderMap::new(), None); Method::POST,
assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert_eq!(
HandshakeError::GetMethodRequired,
handshake(&req).err().unwrap()
);
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), let req = HttpRequest::new(
Version::HTTP_11, HeaderMap::new(), None); Method::GET,
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); Uri::from_str("/").unwrap(),
Version::HTTP_11,
HeaderMap::new(),
None,
);
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("test")); header::UPGRADE,
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), header::HeaderValue::from_static("test"),
Version::HTTP_11, headers, None); );
assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("websocket")); header::UPGRADE,
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), header::HeaderValue::from_static("websocket"),
Version::HTTP_11, headers, None); );
assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap()); let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoConnectionUpgrade,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("websocket")); header::UPGRADE,
headers.insert(header::CONNECTION, header::HeaderValue::from_static("websocket"),
header::HeaderValue::from_static("upgrade")); );
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), headers.insert(
Version::HTTP_11, headers, None); header::CONNECTION,
assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap()); header::HeaderValue::from_static("upgrade"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::NoVersionHeader,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("websocket")); header::UPGRADE,
headers.insert(header::CONNECTION, header::HeaderValue::from_static("websocket"),
header::HeaderValue::from_static("upgrade")); );
headers.insert(header::SEC_WEBSOCKET_VERSION, headers.insert(
header::HeaderValue::from_static("5")); header::CONNECTION,
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), header::HeaderValue::from_static("upgrade"),
Version::HTTP_11, headers, None); );
assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap()); headers.insert(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("5"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::UnsupportedVersion,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("websocket")); header::UPGRADE,
headers.insert(header::CONNECTION, header::HeaderValue::from_static("websocket"),
header::HeaderValue::from_static("upgrade")); );
headers.insert(header::SEC_WEBSOCKET_VERSION, headers.insert(
header::HeaderValue::from_static("13")); header::CONNECTION,
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), header::HeaderValue::from_static("upgrade"),
Version::HTTP_11, headers, None); );
assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap()); headers.insert(
header::SEC_WEBSOCKET_VERSION,
header::HeaderValue::from_static("13"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
HandshakeError::BadWebsocketKey,
handshake(&req).err().unwrap()
);
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::UPGRADE, headers.insert(
header::HeaderValue::from_static("websocket")); header::UPGRADE,
headers.insert(header::CONNECTION, header::HeaderValue::from_static("websocket"),
header::HeaderValue::from_static("upgrade")); );
headers.insert(header::SEC_WEBSOCKET_VERSION, headers.insert(
header::HeaderValue::from_static("13")); header::CONNECTION,
headers.insert(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("upgrade"),
header::HeaderValue::from_static("13")); );
let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), headers.insert(
Version::HTTP_11, headers, None); header::SEC_WEBSOCKET_VERSION,
assert_eq!(StatusCode::SWITCHING_PROTOCOLS, header::HeaderValue::from_static("13"),
handshake(&req).unwrap().finish().status()); );
headers.insert(
header::SEC_WEBSOCKET_KEY,
header::HeaderValue::from_static("13"),
);
let req = HttpRequest::new(
Method::GET,
Uri::from_str("/").unwrap(),
Version::HTTP_11,
headers,
None,
);
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status()
);
} }
#[test] #[test]

View File

@ -1,7 +1,7 @@
use std::fmt;
use std::convert::{Into, From};
use sha1;
use base64; use base64;
use sha1;
use std::convert::{From, Into};
use std::fmt;
use self::OpCode::*; use self::OpCode::*;
/// Operation codes as part of rfc6455. /// Operation codes as part of rfc6455.
@ -26,52 +26,54 @@ pub enum OpCode {
impl fmt::Display for OpCode { impl fmt::Display for OpCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Continue => write!(f, "CONTINUE"), Continue => write!(f, "CONTINUE"),
Text => write!(f, "TEXT"), Text => write!(f, "TEXT"),
Binary => write!(f, "BINARY"), Binary => write!(f, "BINARY"),
Close => write!(f, "CLOSE"), Close => write!(f, "CLOSE"),
Ping => write!(f, "PING"), Ping => write!(f, "PING"),
Pong => write!(f, "PONG"), Pong => write!(f, "PONG"),
Bad => write!(f, "BAD"), Bad => write!(f, "BAD"),
} }
} }
} }
impl Into<u8> for OpCode { impl Into<u8> for OpCode {
fn into(self) -> u8 { fn into(self) -> u8 {
match self { match self {
Continue => 0, Continue => 0,
Text => 1, Text => 1,
Binary => 2, Binary => 2,
Close => 8, Close => 8,
Ping => 9, Ping => 9,
Pong => 10, Pong => 10,
Bad => { Bad => {
debug_assert!(false, "Attempted to convert invalid opcode to u8. This is a bug."); debug_assert!(
8 // if this somehow happens, a close frame will help us tear down quickly false,
"Attempted to convert invalid opcode to u8. This is a bug."
);
8 // if this somehow happens, a close frame will help us tear down quickly
} }
} }
} }
} }
impl From<u8> for OpCode { impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode { fn from(byte: u8) -> OpCode {
match byte { match byte {
0 => Continue, 0 => Continue,
1 => Text, 1 => Text,
2 => Binary, 2 => Binary,
8 => Close, 8 => Close,
9 => Ping, 9 => Ping,
10 => Pong, 10 => Pong,
_ => Bad _ => Bad,
} }
} }
} }
use self::CloseCode::*; use self::CloseCode::*;
/// Status code used to indicate why an endpoint is closing the `WebSocket` connection. /// Status code used to indicate why an endpoint is closing the `WebSocket`
/// connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)] #[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum CloseCode { pub enum CloseCode {
/// Indicates a normal closure, meaning that the purpose for /// Indicates a normal closure, meaning that the purpose for
@ -125,12 +127,13 @@ pub enum CloseCode {
/// it encountered an unexpected condition that prevented it from /// it encountered an unexpected condition that prevented it from
/// fulfilling the request. /// fulfilling the request.
Error, Error,
/// Indicates that the server is restarting. A client may choose to reconnect, /// Indicates that the server is restarting. A client may choose to
/// and if it does, it should use a randomized delay of 5-30 seconds between attempts. /// reconnect, and if it does, it should use a randomized delay of 5-30
/// seconds between attempts.
Restart, Restart,
/// Indicates that the server is overloaded and the client should either connect /// Indicates that the server is overloaded and the client should either
/// to a different IP (when multiple targets exist), or reconnect to the same IP /// connect to a different IP (when multiple targets exist), or
/// when a user has performed an action. /// reconnect to the same IP when a user has performed an action.
Again, Again,
#[doc(hidden)] #[doc(hidden)]
Tls, Tls,
@ -141,31 +144,29 @@ pub enum CloseCode {
} }
impl Into<u16> for CloseCode { impl Into<u16> for CloseCode {
fn into(self) -> u16 { fn into(self) -> u16 {
match self { match self {
Normal => 1000, Normal => 1000,
Away => 1001, Away => 1001,
Protocol => 1002, Protocol => 1002,
Unsupported => 1003, Unsupported => 1003,
Status => 1005, Status => 1005,
Abnormal => 1006, Abnormal => 1006,
Invalid => 1007, Invalid => 1007,
Policy => 1008, Policy => 1008,
Size => 1009, Size => 1009,
Extension => 1010, Extension => 1010,
Error => 1011, Error => 1011,
Restart => 1012, Restart => 1012,
Again => 1013, Again => 1013,
Tls => 1015, Tls => 1015,
Empty => 0, Empty => 0,
Other(code) => code, Other(code) => code,
} }
} }
} }
impl From<u16> for CloseCode { impl From<u16> for CloseCode {
fn from(code: u16) -> CloseCode { fn from(code: u16) -> CloseCode {
match code { match code {
1000 => Normal, 1000 => Normal,
@ -182,7 +183,7 @@ impl From<u16> for CloseCode {
1012 => Restart, 1012 => Restart,
1013 => Again, 1013 => Again,
1015 => Tls, 1015 => Tls,
0 => Empty, 0 => Empty,
_ => Other(code), _ => Other(code),
} }
} }
@ -200,7 +201,6 @@ pub(crate) fn hash_key(key: &[u8]) -> String {
base64::encode(&hasher.digest().bytes()) base64::encode(&hasher.digest().bytes())
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
#![allow(unused_imports, unused_variables, dead_code)] #![allow(unused_imports, unused_variables, dead_code)]
@ -210,9 +210,9 @@ mod test {
($from:expr => $opcode:pat) => { ($from:expr => $opcode:pat) => {
match OpCode::from($from) { match OpCode::from($from) {
e @ $opcode => (), e @ $opcode => (),
e => unreachable!("{:?}", e) e => unreachable!("{:?}", e),
} }
} };
} }
macro_rules! opcode_from { macro_rules! opcode_from {
@ -220,9 +220,9 @@ mod test {
let res: u8 = $from.into(); let res: u8 = $from.into();
match res { match res {
e @ $opcode => (), e @ $opcode => (),
e => unreachable!("{:?}", e) e => unreachable!("{:?}", e),
} }
} };
} }
#[test] #[test]

View File

@ -1,49 +1,46 @@
extern crate actix; extern crate actix;
extern crate actix_web; extern crate actix_web;
extern crate bytes; extern crate bytes;
extern crate futures;
extern crate flate2; extern crate flate2;
extern crate futures;
extern crate rand; extern crate rand;
use std::io::Read; use std::io::Read;
use bytes::Bytes; use bytes::Bytes;
use flate2::read::GzDecoder;
use futures::Future; use futures::Future;
use futures::stream::once; use futures::stream::once;
use flate2::read::GzDecoder;
use rand::Rng; use rand::Rng;
use actix_web::*; use actix_web::*;
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
const STR: &str = Hello World Hello World Hello World Hello World Hello World \
"Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World";
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test] #[test]
fn test_simple() { fn test_simple() {
let mut srv = test::TestServer::new( let mut srv =
|app| app.handler(|_| HttpResponse::Ok().body(STR))); test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().header("x-test", "111").finish().unwrap(); let request = srv.get().header("x-test", "111").finish().unwrap();
let repr = format!("{:?}", request); let repr = format!("{:?}", request);
@ -68,23 +65,26 @@ fn test_simple() {
#[test] #[test]
fn test_with_query_parameter() { fn test_with_query_parameter() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|req: HttpRequest| match req.query().get("qp") { app.handler(|req: HttpRequest| match req.query().get("qp") {
Some(_) => HttpResponse::Ok().finish(), Some(_) => HttpResponse::Ok().finish(),
None => HttpResponse::BadRequest().finish(), None => HttpResponse::BadRequest().finish(),
})); })
});
let request = srv.get().uri(srv.url("/?qp=5").as_str()).finish().unwrap(); let request = srv.get()
.uri(srv.url("/?qp=5").as_str())
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
#[test] #[test]
fn test_no_decompress() { fn test_no_decompress() {
let mut srv = test::TestServer::new( let mut srv =
|app| app.handler(|_| HttpResponse::Ok().body(STR))); test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -111,19 +111,23 @@ fn test_no_decompress() {
#[test] #[test]
fn test_client_gzip_encoding() { fn test_client_gzip_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Deflate) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Deflate)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.post() let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(STR).unwrap(); .body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -136,19 +140,23 @@ fn test_client_gzip_encoding() {
fn test_client_gzip_encoding_large() { fn test_client_gzip_encoding_large() {
let data = STR.repeat(10); let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Deflate) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Deflate)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.post() let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(data.clone()).unwrap(); .body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -164,19 +172,23 @@ fn test_client_gzip_encoding_large_random() {
.take(100_000) .take(100_000)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Deflate) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Deflate)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.post() let request = srv.post()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(data.clone()).unwrap(); .body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -185,22 +197,26 @@ fn test_client_gzip_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data)); assert_eq!(bytes, Bytes::from(data));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_client_brotli_encoding() { fn test_client_brotli_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Gzip) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Gzip)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.client(http::Method::POST, "/") let request = srv.client(http::Method::POST, "/")
.content_encoding(http::ContentEncoding::Br) .content_encoding(http::ContentEncoding::Br)
.body(STR).unwrap(); .body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -209,7 +225,7 @@ fn test_client_brotli_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_client_brotli_encoding_large_random() { fn test_client_brotli_encoding_large_random() {
let data = rand::thread_rng() let data = rand::thread_rng()
@ -217,19 +233,23 @@ fn test_client_brotli_encoding_large_random() {
.take(70_000) .take(70_000)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(move |bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(move |bytes: Bytes| {
.content_encoding(http::ContentEncoding::Gzip) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Gzip)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.client(http::Method::POST, "/") let request = srv.client(http::Method::POST, "/")
.content_encoding(http::ContentEncoding::Br) .content_encoding(http::ContentEncoding::Br)
.body(data.clone()).unwrap(); .body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -239,22 +259,26 @@ fn test_client_brotli_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data)); assert_eq!(bytes, Bytes::from(data));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_client_deflate_encoding() { fn test_client_deflate_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Br) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Br)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.post() let request = srv.post()
.content_encoding(http::ContentEncoding::Deflate) .content_encoding(http::ContentEncoding::Deflate)
.body(STR).unwrap(); .body(STR)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -263,7 +287,7 @@ fn test_client_deflate_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_client_deflate_encoding_large_random() { fn test_client_deflate_encoding_large_random() {
let data = rand::thread_rng() let data = rand::thread_rng()
@ -271,19 +295,23 @@ fn test_client_deflate_encoding_large_random() {
.take(70_000) .take(70_000)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Br) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Br)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let request = srv.post() let request = srv.post()
.content_encoding(http::ContentEncoding::Deflate) .content_encoding(http::ContentEncoding::Deflate)
.body(data.clone()).unwrap(); .body(data.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -294,20 +322,25 @@ fn test_client_deflate_encoding_large_random() {
#[test] #[test]
fn test_client_streaming_explicit() { fn test_client_streaming_explicit() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler( app.handler(|req: HttpRequest| {
|req: HttpRequest| req.body() req.body()
.map_err(Error::from) .map_err(Error::from)
.and_then(|body| { .and_then(|body| {
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
.chunked() .chunked()
.content_encoding(http::ContentEncoding::Identity) .content_encoding(http::ContentEncoding::Identity)
.body(body))}) .body(body))
.responder())); })
.responder()
})
});
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap(); let request = srv.get()
.body(Body::Streaming(Box::new(body)))
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -318,12 +351,14 @@ fn test_client_streaming_explicit() {
#[test] #[test]
fn test_body_streaming_implicit() { fn test_body_streaming_implicit() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))})); .body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -338,7 +373,7 @@ fn test_body_streaming_implicit() {
fn test_client_cookie_handling() { fn test_client_cookie_handling() {
use actix_web::http::Cookie; use actix_web::http::Cookie;
fn err() -> Error { fn err() -> Error {
use std::io::{ErrorKind, Error as IoError}; use std::io::{Error as IoError, ErrorKind};
// stub some generic error // stub some generic error
Error::from(IoError::from(ErrorKind::NotFound)) Error::from(IoError::from(ErrorKind::NotFound))
} }
@ -352,13 +387,12 @@ fn test_client_cookie_handling() {
// Q: are all these clones really necessary? A: Yes, possibly // Q: are all these clones really necessary? A: Yes, possibly
let cookie1b = cookie1.clone(); let cookie1b = cookie1.clone();
let cookie2b = cookie2.clone(); let cookie2b = cookie2.clone();
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| { let cookie1 = cookie1b.clone();
let cookie1 = cookie1b.clone(); let cookie2 = cookie2b.clone();
let cookie2 = cookie2b.clone(); app.handler(move |req: HttpRequest| {
app.handler(move |req: HttpRequest| { // Check cookies were sent correctly
// Check cookies were sent correctly req.cookie("cookie1").ok_or_else(err)
req.cookie("cookie1").ok_or_else(err)
.and_then(|c1| if c1.value() == "value1" { .and_then(|c1| if c1.value() == "value1" {
Ok(()) Ok(())
} else { } else {
@ -376,8 +410,8 @@ fn test_client_cookie_handling() {
.cookie(cookie2.clone()) .cookie(cookie2.clone())
.finish() .finish()
) )
}) })
}); });
let request = srv.get() let request = srv.get()
.cookie(cookie1.clone()) .cookie(cookie1.clone())

View File

@ -1,11 +1,12 @@
extern crate actix; extern crate actix;
extern crate actix_web; extern crate actix_web;
extern crate tokio_core; extern crate bytes;
extern crate futures; extern crate futures;
extern crate h2; extern crate h2;
extern crate http; extern crate http;
extern crate bytes; extern crate tokio_core;
#[macro_use] extern crate serde_derive; #[macro_use]
extern crate serde_derive;
use actix_web::*; use actix_web::*;
use bytes::Bytes; use bytes::Bytes;
@ -19,15 +20,16 @@ struct PParam {
#[test] #[test]
fn test_path_extractor() { fn test_path_extractor() {
let mut srv = test::TestServer::new(|app| { let mut srv = test::TestServer::new(|app| {
app.resource( app.resource("/{username}/index.html", |r| {
"/{username}/index.html", |r| r.with( r.with(|p: Path<PParam>| format!("Welcome {}!", p.username))
|p: Path<PParam>| format!("Welcome {}!", p.username))); });
} });
);
// client request // client request
let request = srv.get().uri(srv.url("/test/index.html")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/test/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -39,15 +41,16 @@ fn test_path_extractor() {
#[test] #[test]
fn test_query_extractor() { fn test_query_extractor() {
let mut srv = test::TestServer::new(|app| { let mut srv = test::TestServer::new(|app| {
app.resource( app.resource("/index.html", |r| {
"/index.html", |r| r.with( r.with(|p: Query<PParam>| format!("Welcome {}!", p.username))
|p: Query<PParam>| format!("Welcome {}!", p.username))); });
} });
);
// client request // client request
let request = srv.get().uri(srv.url("/index.html?username=test")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/index.html?username=test"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -56,8 +59,10 @@ fn test_query_extractor() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test!")); assert_eq!(bytes, Bytes::from_static(b"Welcome test!"));
// client request // client request
let request = srv.get().uri(srv.url("/index.html")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST); assert_eq!(response.status(), StatusCode::BAD_REQUEST);
} }
@ -65,16 +70,18 @@ fn test_query_extractor() {
#[test] #[test]
fn test_path_and_query_extractor() { fn test_path_and_query_extractor() {
let mut srv = test::TestServer::new(|app| { let mut srv = test::TestServer::new(|app| {
app.resource( app.resource("/{username}/index.html", |r| {
"/{username}/index.html", |r| r.route().with2( r.route().with2(|p: Path<PParam>, q: Query<PParam>| {
|p: Path<PParam>, q: Query<PParam>| format!("Welcome {} - {}!", p.username, q.username)
format!("Welcome {} - {}!", p.username, q.username))); })
} });
); });
// client request // client request
let request = srv.get().uri(srv.url("/test1/index.html?username=test2")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/test1/index.html?username=test2"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -83,8 +90,10 @@ fn test_path_and_query_extractor() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!")); assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!"));
// client request // client request
let request = srv.get().uri(srv.url("/test1/index.html")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/test1/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST); assert_eq!(response.status(), StatusCode::BAD_REQUEST);
} }
@ -92,16 +101,19 @@ fn test_path_and_query_extractor() {
#[test] #[test]
fn test_path_and_query_extractor2() { fn test_path_and_query_extractor2() {
let mut srv = test::TestServer::new(|app| { let mut srv = test::TestServer::new(|app| {
app.resource( app.resource("/{username}/index.html", |r| {
"/{username}/index.html", |r| r.route().with3( r.route()
|_: HttpRequest, p: Path<PParam>, q: Query<PParam>| .with3(|_: HttpRequest, p: Path<PParam>, q: Query<PParam>| {
format!("Welcome {} - {}!", p.username, q.username))); format!("Welcome {} - {}!", p.username, q.username)
} })
); });
});
// client request // client request
let request = srv.get().uri(srv.url("/test1/index.html?username=test2")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/test1/index.html?username=test2"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -110,8 +122,10 @@ fn test_path_and_query_extractor2() {
assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!")); assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!"));
// client request // client request
let request = srv.get().uri(srv.url("/test1/index.html")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/test1/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST); assert_eq!(response.status(), StatusCode::BAD_REQUEST);
} }
@ -123,8 +137,10 @@ fn test_non_ascii_route() {
}); });
// client request // client request
let request = srv.get().uri(srv.url("/中文/index.html")) let request = srv.get()
.finish().unwrap(); .uri(srv.url("/中文/index.html"))
.finish()
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());

View File

@ -1,60 +1,58 @@
extern crate actix; extern crate actix;
extern crate actix_web; extern crate actix_web;
extern crate tokio_core; extern crate bytes;
extern crate flate2;
extern crate futures; extern crate futures;
extern crate h2; extern crate h2;
extern crate http as modhttp; extern crate http as modhttp;
extern crate bytes;
extern crate flate2;
extern crate rand; extern crate rand;
extern crate tokio_core;
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
extern crate brotli2; extern crate brotli2;
use std::{net, thread, time}; #[cfg(feature = "brotli")]
use std::io::{Read, Write}; use brotli2::write::{BrotliDecoder, BrotliEncoder};
use std::sync::{Arc, mpsc}; use bytes::{Bytes, BytesMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use flate2::Compression; use flate2::Compression;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use flate2::write::{GzEncoder, DeflateEncoder, DeflateDecoder}; use flate2::write::{DeflateDecoder, DeflateEncoder, GzEncoder};
#[cfg(feature="brotli")]
use brotli2::write::{BrotliEncoder, BrotliDecoder};
use futures::{Future, Stream};
use futures::stream::once; use futures::stream::once;
use futures::{Future, Stream};
use h2::client as h2client; use h2::client as h2client;
use bytes::{Bytes, BytesMut};
use modhttp::Request; use modhttp::Request;
use rand::Rng;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{mpsc, Arc};
use std::{net, thread, time};
use tokio_core::net::TcpStream; use tokio_core::net::TcpStream;
use tokio_core::reactor::Core; use tokio_core::reactor::Core;
use rand::Rng;
use actix::System; use actix::System;
use actix_web::*; use actix_web::*;
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
const STR: &str = Hello World Hello World Hello World Hello World Hello World \
"Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World";
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test] #[test]
fn test_start() { fn test_start() {
@ -63,11 +61,13 @@ fn test_start() {
thread::spawn(move || { thread::spawn(move || {
let sys = System::new("test"); let sys = System::new("test");
let srv = server::new( let srv = server::new(|| {
|| vec![App::new() vec![
.resource( App::new().resource("/", |r| {
"/", |r| r.method(http::Method::GET) r.method(http::Method::GET).f(|_| HttpResponse::Ok())
.f(|_|HttpResponse::Ok()))]); }),
]
});
let srv = srv.bind("127.0.0.1:0").unwrap(); let srv = srv.bind("127.0.0.1:0").unwrap();
let addr = srv.addrs()[0]; let addr = srv.addrs()[0];
@ -79,7 +79,9 @@ fn test_start() {
let mut sys = System::new("test-server"); let mut sys = System::new("test-server");
{ {
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap(); let response = sys.run_until_complete(req.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
@ -94,7 +96,9 @@ fn test_start() {
thread::sleep(time::Duration::from_millis(200)); thread::sleep(time::Duration::from_millis(200));
{ {
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap(); let response = sys.run_until_complete(req.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
@ -108,10 +112,13 @@ fn test_shutdown() {
thread::spawn(move || { thread::spawn(move || {
let sys = System::new("test"); let sys = System::new("test");
let srv = server::new( let srv = server::new(|| {
|| vec![App::new() vec![
.resource( App::new().resource("/", |r| {
"/", |r| r.method(http::Method::GET).f(|_| HttpResponse::Ok()))]); r.method(http::Method::GET).f(|_| HttpResponse::Ok())
}),
]
});
let srv = srv.bind("127.0.0.1:0").unwrap(); let srv = srv.bind("127.0.0.1:0").unwrap();
let addr = srv.addrs()[0]; let addr = srv.addrs()[0];
@ -124,9 +131,11 @@ fn test_shutdown() {
let mut sys = System::new("test-server"); let mut sys = System::new("test-server");
{ {
let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
.finish()
.unwrap();
let response = sys.run_until_complete(req.send()).unwrap(); let response = sys.run_until_complete(req.send()).unwrap();
srv_addr.do_send(server::StopServer{graceful: true}); srv_addr.do_send(server::StopServer { graceful: true });
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
@ -146,30 +155,31 @@ fn test_simple() {
fn test_headers() { fn test_headers() {
let data = STR.repeat(10); let data = STR.repeat(10);
let srv_data = Arc::new(data.clone()); let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| { let data = srv_data.clone();
let data = srv_data.clone(); app.handler(move |_| {
app.handler(move |_| { let mut builder = HttpResponse::Ok();
let mut builder = HttpResponse::Ok(); for idx in 0..90 {
for idx in 0..90 { builder.header(
builder.header( format!("X-TEST-{}", idx).as_str(),
format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST "); );
} }
builder.body(data.as_ref())}) builder.body(data.as_ref())
}); })
});
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -182,8 +192,8 @@ fn test_headers() {
#[test] #[test]
fn test_body() { fn test_body() {
let mut srv = test::TestServer::new( let mut srv =
|app| app.handler(|_| HttpResponse::Ok().body(STR))); test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -196,11 +206,13 @@ fn test_body() {
#[test] #[test]
fn test_body_gzip() { fn test_body_gzip() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler( app.handler(|_| {
|_| HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(STR))); .body(STR)
})
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -221,13 +233,14 @@ fn test_body_gzip_large() {
let data = STR.repeat(10); let data = STR.repeat(10);
let srv_data = Arc::new(data.clone()); let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| { let data = srv_data.clone();
let data = srv_data.clone(); app.handler(move |_| {
app.handler( HttpResponse::Ok()
move |_| HttpResponse::Ok() .content_encoding(http::ContentEncoding::Gzip)
.content_encoding(http::ContentEncoding::Gzip) .body(data.as_ref())
.body(data.as_ref()))}); })
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -251,13 +264,14 @@ fn test_body_gzip_large_random() {
.collect::<String>(); .collect::<String>();
let srv_data = Arc::new(data.clone()); let srv_data = Arc::new(data.clone());
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| { let data = srv_data.clone();
let data = srv_data.clone(); app.handler(move |_| {
app.handler( HttpResponse::Ok()
move |_| HttpResponse::Ok() .content_encoding(http::ContentEncoding::Gzip)
.content_encoding(http::ContentEncoding::Gzip) .body(data.as_ref())
.body(data.as_ref()))}); })
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -276,12 +290,14 @@ fn test_body_gzip_large_random() {
#[test] #[test]
fn test_body_chunked_implicit() { fn test_body_chunked_implicit() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))})); .body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -297,15 +313,17 @@ fn test_body_chunked_implicit() {
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_body_br_streaming() { fn test_body_br_streaming() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br) .content_encoding(http::ContentEncoding::Br)
.body(Body::Streaming(Box::new(body)))})); .body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -323,17 +341,23 @@ fn test_body_br_streaming() {
#[test] #[test]
fn test_head_empty() { fn test_head_empty() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
HttpResponse::Ok() HttpResponse::Ok()
.content_length(STR.len() as u64).finish()})); .content_length(STR.len() as u64)
.finish()
})
});
let request = srv.head().finish().unwrap(); let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
} }
@ -344,18 +368,24 @@ fn test_head_empty() {
#[test] #[test]
fn test_head_binary() { fn test_head_binary() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity) .content_encoding(http::ContentEncoding::Identity)
.content_length(100).body(STR)})); .content_length(100)
.body(STR)
})
});
let request = srv.head().finish().unwrap(); let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
} }
@ -366,32 +396,38 @@ fn test_head_binary() {
#[test] #[test]
fn test_head_binary2() { fn test_head_binary2() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Identity) .content_encoding(http::ContentEncoding::Identity)
.body(STR) .body(STR)
})); })
});
let request = srv.head().finish().unwrap(); let request = srv.head().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
} }
} }
#[test] #[test]
fn test_body_length() { fn test_body_length() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok() HttpResponse::Ok()
.content_length(STR.len() as u64) .content_length(STR.len() as u64)
.content_encoding(http::ContentEncoding::Identity) .content_encoding(http::ContentEncoding::Identity)
.body(Body::Streaming(Box::new(body)))})); .body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -404,13 +440,15 @@ fn test_body_length() {
#[test] #[test]
fn test_body_chunked_explicit() { fn test_body_chunked_explicit() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|_| { app.handler(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
HttpResponse::Ok() HttpResponse::Ok()
.chunked() .chunked()
.content_encoding(http::ContentEncoding::Gzip) .content_encoding(http::ContentEncoding::Gzip)
.body(Body::Streaming(Box::new(body)))})); .body(Body::Streaming(Box::new(body)))
})
});
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -428,11 +466,13 @@ fn test_body_chunked_explicit() {
#[test] #[test]
fn test_body_deflate() { fn test_body_deflate() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler( app.handler(|_| {
|_| HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Deflate) .content_encoding(http::ContentEncoding::Deflate)
.body(STR))); .body(STR)
})
});
// client request // client request
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
@ -449,14 +489,16 @@ fn test_body_deflate() {
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_body_brotli() { fn test_body_brotli() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler( app.handler(|_| {
|_| HttpResponse::Ok() HttpResponse::Ok()
.content_encoding(http::ContentEncoding::Br) .content_encoding(http::ContentEncoding::Br)
.body(STR))); .body(STR)
})
});
// client request // client request
let request = srv.get().disable_decompress().finish().unwrap(); let request = srv.get().disable_decompress().finish().unwrap();
@ -475,14 +517,17 @@ fn test_body_brotli() {
#[test] #[test]
fn test_gzip_encoding() { fn test_gzip_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let mut e = GzEncoder::new(Vec::new(), Compression::default()); let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -491,7 +536,8 @@ fn test_gzip_encoding() {
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip") .header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap(); .body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -503,14 +549,17 @@ fn test_gzip_encoding() {
#[test] #[test]
fn test_gzip_encoding_large() { fn test_gzip_encoding_large() {
let data = STR.repeat(10); let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let mut e = GzEncoder::new(Vec::new(), Compression::default()); let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -519,7 +568,8 @@ fn test_gzip_encoding_large() {
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip") .header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap(); .body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -535,14 +585,17 @@ fn test_reading_gzip_encoding_large_random() {
.take(60_000) .take(60_000)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
// client request // client request
let mut e = GzEncoder::new(Vec::new(), Compression::default()); let mut e = GzEncoder::new(Vec::new(), Compression::default());
@ -551,7 +604,8 @@ fn test_reading_gzip_encoding_large_random() {
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "gzip") .header(http::header::CONTENT_ENCODING, "gzip")
.body(enc.clone()).unwrap(); .body(enc.clone())
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -563,14 +617,17 @@ fn test_reading_gzip_encoding_large_random() {
#[test] #[test]
fn test_reading_deflate_encoding() { fn test_reading_deflate_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(STR.as_ref()).unwrap(); e.write_all(STR.as_ref()).unwrap();
@ -579,7 +636,8 @@ fn test_reading_deflate_encoding() {
// client request // client request
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate") .header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap(); .body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -591,14 +649,17 @@ fn test_reading_deflate_encoding() {
#[test] #[test]
fn test_reading_deflate_encoding_large() { fn test_reading_deflate_encoding_large() {
let data = STR.repeat(10); let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(data.as_ref()).unwrap(); e.write_all(data.as_ref()).unwrap();
@ -607,7 +668,8 @@ fn test_reading_deflate_encoding_large() {
// client request // client request
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate") .header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap(); .body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -623,14 +685,17 @@ fn test_reading_deflate_encoding_large_random() {
.take(160_000) .take(160_000)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); let mut e = DeflateEncoder::new(Vec::new(), Compression::default());
e.write_all(data.as_ref()).unwrap(); e.write_all(data.as_ref()).unwrap();
@ -639,7 +704,8 @@ fn test_reading_deflate_encoding_large_random() {
// client request // client request
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "deflate") .header(http::header::CONTENT_ENCODING, "deflate")
.body(enc).unwrap(); .body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -649,17 +715,20 @@ fn test_reading_deflate_encoding_large_random() {
assert_eq!(bytes, Bytes::from(data)); assert_eq!(bytes, Bytes::from(data));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_brotli_encoding() { fn test_brotli_encoding() {
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
let mut e = BrotliEncoder::new(Vec::new(), 5); let mut e = BrotliEncoder::new(Vec::new(), 5);
e.write_all(STR.as_ref()).unwrap(); e.write_all(STR.as_ref()).unwrap();
@ -668,7 +737,8 @@ fn test_brotli_encoding() {
// client request // client request
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "br") .header(http::header::CONTENT_ENCODING, "br")
.body(enc).unwrap(); .body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -677,18 +747,21 @@ fn test_brotli_encoding() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[cfg(feature="brotli")] #[cfg(feature = "brotli")]
#[test] #[test]
fn test_brotli_encoding_large() { fn test_brotli_encoding_large() {
let data = STR.repeat(10); let data = STR.repeat(10);
let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { let mut srv = test::TestServer::new(|app| {
req.body() app.handler(|req: HttpRequest| {
.and_then(|bytes: Bytes| { req.body()
Ok(HttpResponse::Ok() .and_then(|bytes: Bytes| {
.content_encoding(http::ContentEncoding::Identity) Ok(HttpResponse::Ok()
.body(bytes)) .content_encoding(http::ContentEncoding::Identity)
}).responder()} .body(bytes))
)); })
.responder()
})
});
let mut e = BrotliEncoder::new(Vec::new(), 5); let mut e = BrotliEncoder::new(Vec::new(), 5);
e.write_all(data.as_ref()).unwrap(); e.write_all(data.as_ref()).unwrap();
@ -697,7 +770,8 @@ fn test_brotli_encoding_large() {
// client request // client request
let request = srv.post() let request = srv.post()
.header(http::header::CONTENT_ENCODING, "br") .header(http::header::CONTENT_ENCODING, "br")
.body(enc).unwrap(); .body(enc)
.unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
@ -708,48 +782,46 @@ fn test_brotli_encoding_large() {
#[test] #[test]
fn test_h2() { fn test_h2() {
let srv = test::TestServer::new(|app| app.handler(|_|{ let srv = test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR)));
HttpResponse::Ok().body(STR)
}));
let addr = srv.addr(); let addr = srv.addr();
let mut core = Core::new().unwrap(); let mut core = Core::new().unwrap();
let handle = core.handle(); let handle = core.handle();
let tcp = TcpStream::connect(&addr, &handle); let tcp = TcpStream::connect(&addr, &handle);
let tcp = tcp.then(|res| { let tcp = tcp.then(|res| h2client::handshake(res.unwrap()))
h2client::handshake(res.unwrap()) .then(move |res| {
}).then(move |res| { let (mut client, h2) = res.unwrap();
let (mut client, h2) = res.unwrap();
let request = Request::builder() let request = Request::builder()
.uri(format!("https://{}/", addr).as_str()) .uri(format!("https://{}/", addr).as_str())
.body(()) .body(())
.unwrap(); .unwrap();
let (response, _) = client.send_request(request, false).unwrap(); let (response, _) = client.send_request(request, false).unwrap();
// Spawn a task to run the conn... // Spawn a task to run the conn...
handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e)));
response.and_then(|response| { response.and_then(|response| {
assert_eq!(response.status(), http::StatusCode::OK); assert_eq!(response.status(), http::StatusCode::OK);
let (_, body) = response.into_parts(); let (_, body) = response.into_parts();
body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> { body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> {
b.extend(c); b.extend(c);
Ok(b) Ok(b)
})
}) })
}) });
});
let _res = core.run(tcp); let _res = core.run(tcp);
// assert_eq!(res.unwrap(), Bytes::from_static(STR.as_ref())); // assert_eq!(res.unwrap(), Bytes::from_static(STR.as_ref()));
} }
#[test] #[test]
fn test_application() { fn test_application() {
let mut srv = test::TestServer::with_factory( let mut srv = test::TestServer::with_factory(|| {
|| App::new().resource("/", |r| r.f(|_| HttpResponse::Ok()))); App::new().resource("/", |r| r.f(|_| HttpResponse::Ok()))
});
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -764,17 +836,28 @@ struct MiddlewareTest {
impl<S> middleware::Middleware<S> for MiddlewareTest { impl<S> middleware::Middleware<S> for MiddlewareTest {
fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> { fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> {
self.start.store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); self.start.store(
self.start.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
Ok(middleware::Started::Done) Ok(middleware::Started::Done)
} }
fn response(&self, _: &mut HttpRequest<S>, resp: HttpResponse) -> Result<middleware::Response> { fn response(
self.response.store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); &self, _: &mut HttpRequest<S>, resp: HttpResponse
) -> Result<middleware::Response> {
self.response.store(
self.response.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
Ok(middleware::Response::Done(resp)) Ok(middleware::Response::Done(resp))
} }
fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished { fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished {
self.finish.store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); self.finish.store(
self.finish.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
middleware::Finished::Done middleware::Finished::Done
} }
} }
@ -789,12 +872,13 @@ fn test_middlewares() {
let act_num2 = Arc::clone(&num2); let act_num2 = Arc::clone(&num2);
let act_num3 = Arc::clone(&num3); let act_num3 = Arc::clone(&num3);
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| app.middleware(MiddlewareTest{start: Arc::clone(&act_num1), app.middleware(MiddlewareTest {
response: Arc::clone(&act_num2), start: Arc::clone(&act_num1),
finish: Arc::clone(&act_num3)}) response: Arc::clone(&act_num2),
.handler(|_| HttpResponse::Ok()) finish: Arc::clone(&act_num3),
); }).handler(|_| HttpResponse::Ok())
});
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();
@ -805,7 +889,6 @@ fn test_middlewares() {
assert_eq!(num3.load(Ordering::Relaxed), 1); assert_eq!(num3.load(Ordering::Relaxed), 1);
} }
#[test] #[test]
fn test_resource_middlewares() { fn test_resource_middlewares() {
let num1 = Arc::new(AtomicUsize::new(0)); let num1 = Arc::new(AtomicUsize::new(0));
@ -816,13 +899,13 @@ fn test_resource_middlewares() {
let act_num2 = Arc::clone(&num2); let act_num2 = Arc::clone(&num2);
let act_num3 = Arc::clone(&num3); let act_num3 = Arc::clone(&num3);
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(move |app| {
move |app| app app.middleware(MiddlewareTest {
.middleware(MiddlewareTest{start: Arc::clone(&act_num1), start: Arc::clone(&act_num1),
response: Arc::clone(&act_num2), response: Arc::clone(&act_num2),
finish: Arc::clone(&act_num3)}) finish: Arc::clone(&act_num3),
.handler(|_| HttpResponse::Ok()) }).handler(|_| HttpResponse::Ok())
); });
let request = srv.get().finish().unwrap(); let request = srv.get().finish().unwrap();
let response = srv.execute(request.send()).unwrap(); let response = srv.execute(request.send()).unwrap();

View File

@ -1,19 +1,19 @@
extern crate actix; extern crate actix;
extern crate actix_web; extern crate actix_web;
extern crate bytes;
extern crate futures; extern crate futures;
extern crate http; extern crate http;
extern crate bytes;
extern crate rand; extern crate rand;
use bytes::Bytes; use bytes::Bytes;
use futures::Stream; use futures::Stream;
use rand::Rng; use rand::Rng;
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
extern crate openssl; extern crate openssl;
use actix_web::*;
use actix::prelude::*; use actix::prelude::*;
use actix_web::*;
struct Ws; struct Ws;
@ -22,7 +22,6 @@ impl Actor for Ws {
} }
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws { impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {
ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Ping(msg) => ctx.pong(&msg),
@ -36,8 +35,7 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for Ws {
#[test] #[test]
fn test_simple() { fn test_simple() {
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
|app| app.handler(|req| ws::start(req, Ws)));
let (reader, mut writer) = srv.ws().unwrap(); let (reader, mut writer) = srv.ws().unwrap();
writer.text("text"); writer.text("text");
@ -46,7 +44,12 @@ fn test_simple() {
writer.binary(b"text".as_ref()); writer.binary(b"text".as_ref());
let (item, reader) = srv.execute(reader.into_future()).unwrap(); let (item, reader) = srv.execute(reader.into_future()).unwrap();
assert_eq!(item, Some(ws::Message::Binary(Bytes::from_static(b"text").into()))); assert_eq!(
item,
Some(ws::Message::Binary(
Bytes::from_static(b"text").into()
))
);
writer.ping("ping"); writer.ping("ping");
let (item, reader) = srv.execute(reader.into_future()).unwrap(); let (item, reader) = srv.execute(reader.into_future()).unwrap();
@ -64,8 +67,7 @@ fn test_large_text() {
.take(65_536) .take(65_536)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
|app| app.handler(|req| ws::start(req, Ws)));
let (mut reader, mut writer) = srv.ws().unwrap(); let (mut reader, mut writer) = srv.ws().unwrap();
for _ in 0..100 { for _ in 0..100 {
@ -83,15 +85,17 @@ fn test_large_bin() {
.take(65_536) .take(65_536)
.collect::<String>(); .collect::<String>();
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws)));
|app| app.handler(|req| ws::start(req, Ws)));
let (mut reader, mut writer) = srv.ws().unwrap(); let (mut reader, mut writer) = srv.ws().unwrap();
for _ in 0..100 { for _ in 0..100 {
writer.binary(data.clone()); writer.binary(data.clone());
let (item, r) = srv.execute(reader.into_future()).unwrap(); let (item, r) = srv.execute(reader.into_future()).unwrap();
reader = r; reader = r;
assert_eq!(item, Some(ws::Message::Binary(Binary::from(data.clone())))); assert_eq!(
item,
Some(ws::Message::Binary(Binary::from(data.clone())))
);
} }
} }
@ -115,18 +119,19 @@ impl Ws2 {
} else { } else {
ctx.text("0".repeat(65_536)); ctx.text("0".repeat(65_536));
} }
ctx.drain().and_then(|_, act, ctx| { ctx.drain()
act.count += 1; .and_then(|_, act, ctx| {
if act.count != 10_000 { act.count += 1;
act.send(ctx); if act.count != 10_000 {
} act.send(ctx);
actix::fut::ok(()) }
}).wait(ctx); actix::fut::ok(())
})
.wait(ctx);
} }
} }
impl StreamHandler<ws::Message, ws::ProtocolError> for Ws2 { impl StreamHandler<ws::Message, ws::ProtocolError> for Ws2 {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {
ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Ping(msg) => ctx.pong(&msg),
@ -142,8 +147,17 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for Ws2 {
fn test_server_send_text() { fn test_server_send_text() {
let data = Some(ws::Message::Text("0".repeat(65_536))); let data = Some(ws::Message::Text("0".repeat(65_536)));
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false}))); app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: false,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap(); let (mut reader, _writer) = srv.ws().unwrap();
for _ in 0..10_000 { for _ in 0..10_000 {
@ -157,8 +171,17 @@ fn test_server_send_text() {
fn test_server_send_bin() { fn test_server_send_bin() {
let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536)))); let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536))));
let mut srv = test::TestServer::new( let mut srv = test::TestServer::new(|app| {
|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: true}))); app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: true,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap(); let (mut reader, _writer) = srv.ws().unwrap();
for _ in 0..10_000 { for _ in 0..10_000 {
@ -169,19 +192,33 @@ fn test_server_send_bin() {
} }
#[test] #[test]
#[cfg(feature="alpn")] #[cfg(feature = "alpn")]
fn test_ws_server_ssl() { fn test_ws_server_ssl() {
extern crate openssl; extern crate openssl;
use openssl::ssl::{SslMethod, SslAcceptor, SslFiletype}; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
// load ssl keys // load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder.set_private_key_file("tests/key.pem", SslFiletype::PEM).unwrap(); builder
builder.set_certificate_chain_file("tests/cert.pem").unwrap(); .set_private_key_file("tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("tests/cert.pem")
.unwrap();
let mut srv = test::TestServer::build() let mut srv = test::TestServer::build()
.ssl(builder.build()) .ssl(builder.build())
.start(|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false}))); .start(|app| {
app.handler(|req| {
ws::start(
req,
Ws2 {
count: 0,
bin: false,
},
)
})
});
let (mut reader, _writer) = srv.ws().unwrap(); let (mut reader, _writer) = srv.ws().unwrap();
let data = Some(ws::Message::Text("0".repeat(65_536))); let data = Some(ws::Message::Text("0".repeat(65_536)));

View File

@ -3,25 +3,24 @@
#![allow(unused_variables)] #![allow(unused_variables)]
extern crate actix; extern crate actix;
extern crate actix_web; extern crate actix_web;
extern crate clap;
extern crate env_logger; extern crate env_logger;
extern crate futures; extern crate futures;
extern crate tokio_core; extern crate num_cpus;
extern crate url;
extern crate clap;
extern crate rand; extern crate rand;
extern crate time; extern crate time;
extern crate num_cpus; extern crate tokio_core;
extern crate url;
use std::time::Duration;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use futures::Future; use futures::Future;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use actix::prelude::*; use actix::prelude::*;
use actix_web::ws; use actix_web::ws;
fn main() { fn main() {
::std::env::set_var("RUST_LOG", "actix_web=info"); ::std::env::set_var("RUST_LOG", "actix_web=info");
let _ = env_logger::init(); let _ = env_logger::init();
@ -62,16 +61,20 @@ fn main() {
let sample_rate = parse_u64_default(matches.value_of("sample-rate"), 1) as usize; let sample_rate = parse_u64_default(matches.value_of("sample-rate"), 1) as usize;
let perf_counters = Arc::new(PerfCounters::new()); let perf_counters = Arc::new(PerfCounters::new());
let payload = Arc::new(thread_rng() let payload = Arc::new(
.gen_ascii_chars() thread_rng()
.take(payload_size) .gen_ascii_chars()
.collect::<String>()); .take(payload_size)
.collect::<String>(),
);
let sys = actix::System::new("ws-client"); let sys = actix::System::new("ws-client");
let _: () = Perf{counters: perf_counters.clone(), let _: () = Perf {
payload: payload.len(), counters: perf_counters.clone(),
sample_rate_secs: sample_rate}.start(); payload: payload.len(),
sample_rate_secs: sample_rate,
}.start();
for t in 0..threads { for t in 0..threads {
let pl = payload.clone(); let pl = payload.clone();
@ -79,46 +82,54 @@ fn main() {
let perf = perf_counters.clone(); let perf = perf_counters.clone();
let addr = Arbiter::new(format!("test {}", t)); let addr = Arbiter::new(format!("test {}", t));
addr.do_send(actix::msgs::Execute::new(move || -> Result<(), ()> { addr.do_send(actix::msgs::Execute::new(
for _ in 0..concurrency { move || -> Result<(), ()> {
let pl2 = pl.clone(); for _ in 0..concurrency {
let perf2 = perf.clone(); let pl2 = pl.clone();
let ws2 = ws.clone(); let perf2 = perf.clone();
let ws2 = ws.clone();
Arbiter::handle().spawn( Arbiter::handle().spawn(
ws::Client::new(&ws) ws::Client::new(&ws)
.write_buffer_capacity(0) .write_buffer_capacity(0)
.connect() .connect()
.map_err(|e| { .map_err(|e| {
println!("Error: {}", e); println!("Error: {}", e);
//Arbiter::system().do_send(actix::msgs::SystemExit(0)); //Arbiter::system().do_send(actix::msgs::SystemExit(0));
() ()
}) })
.map(move |(reader, writer)| { .map(move |(reader, writer)| {
let addr: Addr<Syn, _> = ChatClient::create(move |ctx| { let addr: Addr<Syn, _> =
ChatClient::add_stream(reader, ctx); ChatClient::create(move |ctx| {
ChatClient{url: ws2, ChatClient::add_stream(reader, ctx);
conn: writer, ChatClient {
payload: pl2, url: ws2,
bin: bin, conn: writer,
ts: time::precise_time_ns(), payload: pl2,
perf_counters: perf2, bin: bin,
sent: 0, ts: time::precise_time_ns(),
max_payload_size: max_payload_size, perf_counters: perf2,
} sent: 0,
}); max_payload_size: max_payload_size,
}) }
); });
} }),
Ok(()) );
})); }
Ok(())
},
));
} }
let res = sys.run(); let res = sys.run();
} }
fn parse_u64_default(input: Option<&str>, default: u64) -> u64 { fn parse_u64_default(input: Option<&str>, default: u64) -> u64 {
input.map(|v| v.parse().expect(&format!("not a valid number: {}", v))) input
.map(|v| {
v.parse()
.expect(&format!("not a valid number: {}", v))
})
.unwrap_or(default) .unwrap_or(default)
} }
@ -138,29 +149,32 @@ impl Actor for Perf {
impl Perf { impl Perf {
fn sample_rate(&self, ctx: &mut Context<Self>) { fn sample_rate(&self, ctx: &mut Context<Self>) {
ctx.run_later(Duration::new(self.sample_rate_secs as u64, 0), |act, ctx| { ctx.run_later(
let req_count = act.counters.pull_request_count(); Duration::new(self.sample_rate_secs as u64, 0),
if req_count != 0 { |act, ctx| {
let conns = act.counters.pull_connections_count(); let req_count = act.counters.pull_request_count();
let latency = act.counters.pull_latency_ns(); if req_count != 0 {
let latency_max = act.counters.pull_latency_max_ns(); let conns = act.counters.pull_connections_count();
println!( let latency = act.counters.pull_latency_ns();
"rate: {}, conns: {}, throughput: {:?} kb, latency: {}, latency max: {}", let latency_max = act.counters.pull_latency_max_ns();
req_count / act.sample_rate_secs, println!(
conns / act.sample_rate_secs, "rate: {}, conns: {}, throughput: {:?} kb, latency: {}, latency max: {}",
(((req_count * act.payload) as f64) / 1024.0) / req_count / act.sample_rate_secs,
act.sample_rate_secs as f64, conns / act.sample_rate_secs,
time::Duration::nanoseconds((latency / req_count as u64) as i64), (((req_count * act.payload) as f64) / 1024.0)
time::Duration::nanoseconds(latency_max as i64) / act.sample_rate_secs as f64,
); time::Duration::nanoseconds((latency / req_count as u64) as i64),
} time::Duration::nanoseconds(latency_max as i64)
);
}
act.sample_rate(ctx); act.sample_rate(ctx);
}); },
);
} }
} }
struct ChatClient{ struct ChatClient {
url: String, url: String,
conn: ws::ClientWriter, conn: ws::ClientWriter,
payload: Arc<String>, payload: Arc<String>,
@ -181,7 +195,6 @@ impl Actor for ChatClient {
} }
impl ChatClient { impl ChatClient {
fn send_text(&mut self) -> bool { fn send_text(&mut self) -> bool {
self.sent += self.payload.len(); self.sent += self.payload.len();
@ -193,7 +206,8 @@ impl ChatClient {
let max_payload_size = self.max_payload_size; let max_payload_size = self.max_payload_size;
Arbiter::handle().spawn( Arbiter::handle().spawn(
ws::Client::new(&self.url).connect() ws::Client::new(&self.url)
.connect()
.map_err(|e| { .map_err(|e| {
println!("Error: {}", e); println!("Error: {}", e);
Arbiter::system().do_send(actix::msgs::SystemExit(0)); Arbiter::system().do_send(actix::msgs::SystemExit(0));
@ -202,17 +216,18 @@ impl ChatClient {
.map(move |(reader, writer)| { .map(move |(reader, writer)| {
let addr: Addr<Syn, _> = ChatClient::create(move |ctx| { let addr: Addr<Syn, _> = ChatClient::create(move |ctx| {
ChatClient::add_stream(reader, ctx); ChatClient::add_stream(reader, ctx);
ChatClient{url: ws, ChatClient {
conn: writer, url: ws,
payload: pl, conn: writer,
bin: bin, payload: pl,
ts: time::precise_time_ns(), bin: bin,
perf_counters: perf_counters, ts: time::precise_time_ns(),
sent: 0, perf_counters: perf_counters,
max_payload_size: max_payload_size, sent: 0,
max_payload_size: max_payload_size,
} }
}); });
}) }),
); );
false false
} else { } else {
@ -229,7 +244,6 @@ impl ChatClient {
/// Handle server websocket messages /// Handle server websocket messages
impl StreamHandler<ws::Message, ws::ProtocolError> for ChatClient { impl StreamHandler<ws::Message, ws::ProtocolError> for ChatClient {
fn finished(&mut self, ctx: &mut Context<Self>) { fn finished(&mut self, ctx: &mut Context<Self>) {
ctx.stop() ctx.stop()
} }
@ -239,25 +253,25 @@ impl StreamHandler<ws::Message, ws::ProtocolError> for ChatClient {
ws::Message::Text(txt) => { ws::Message::Text(txt) => {
if txt == self.payload.as_ref().as_str() { if txt == self.payload.as_ref().as_str() {
self.perf_counters.register_request(); self.perf_counters.register_request();
self.perf_counters.register_latency(time::precise_time_ns() - self.ts); self.perf_counters
.register_latency(time::precise_time_ns() - self.ts);
if !self.send_text() { if !self.send_text() {
ctx.stop(); ctx.stop();
} }
} else { } else {
println!("not eaqual"); println!("not eaqual");
} }
}, }
_ => () _ => (),
} }
} }
} }
pub struct PerfCounters { pub struct PerfCounters {
req: AtomicUsize, req: AtomicUsize,
conn: AtomicUsize, conn: AtomicUsize,
lat: AtomicUsize, lat: AtomicUsize,
lat_max: AtomicUsize lat_max: AtomicUsize,
} }
impl PerfCounters { impl PerfCounters {
@ -299,7 +313,11 @@ impl PerfCounters {
self.lat.fetch_add(nanos, Ordering::SeqCst); self.lat.fetch_add(nanos, Ordering::SeqCst);
loop { loop {
let current = self.lat_max.load(Ordering::SeqCst); let current = self.lat_max.load(Ordering::SeqCst);
if current >= nanos || self.lat_max.compare_and_swap(current, nanos, Ordering::SeqCst) == current { if current >= nanos
|| self.lat_max
.compare_and_swap(current, nanos, Ordering::SeqCst)
== current
{
break; break;
} }
} }