1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-08-27 15:35:53 +02:00

Compare commits

..

4 Commits

Author SHA1 Message Date
Harrison
7880d5e2bf Remove extra '#' so docs render correctly. (#1338) 2020-02-07 23:01:23 +09:00
0x1793d1
c23020d266 Fix extra line feed (#1206) 2019-12-09 22:40:37 +06:00
Nikolay Kim
1d45639bed prepare actix-multipart release 2019-12-07 20:02:46 +06:00
Alexander Larsson
6a672c9097 actix-multipart: Fix multipart boundary reading (#1189)
* actix-multipart: Fix multipart boundary reading

If we're not ready to read the first line after the multipart field
(which should be a "\r\n" line) then return NotReady instead of Ready(None)
so that we will get called again to read that line.

Without this I was getting MultipartError::Boundary from read_boundary()
because it got the "\r\n" line instead of the boundary.

* actix-multipart: Test handling of NotReady

Use a stream that reports NoReady and does partial reads in the test_stream
test. This works now, but failed before the previous commit.
2019-12-07 19:58:38 +06:00
122 changed files with 9693 additions and 10539 deletions

View File

@@ -10,9 +10,9 @@ matrix:
include: include:
- rust: stable - rust: stable
- rust: beta - rust: beta
- rust: nightly-2019-11-20 - rust: nightly-2019-08-10
allow_failures: allow_failures:
- rust: nightly-2019-11-20 - rust: nightly-2019-08-10
env: env:
global: global:
@@ -25,7 +25,7 @@ before_install:
- sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev
before_cache: | before_cache: |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin
fi fi
@@ -37,8 +37,8 @@ script:
- cargo update - cargo update
- cargo check --all --no-default-features - cargo check --all --no-default-features
- cargo test --all-features --all -- --nocapture - cargo test --all-features --all -- --nocapture
# - cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. - cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
# - cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. - cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
# Upload docs # Upload docs
after_success: after_success:
@@ -51,7 +51,7 @@ after_success:
echo "Uploaded documentation" echo "Uploaded documentation"
fi fi
- | - |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
taskset -c 0 cargo tarpaulin --out Xml --all --all-features taskset -c 0 cargo tarpaulin --out Xml --all --all-features
bash <(curl -s https://codecov.io/bash) bash <(curl -s https://codecov.io/bash)
echo "Uploaded code coverage" echo "Uploaded code coverage"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web" name = "actix-web"
version = "2.0.0-alpha.1" version = "1.0.9"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
readme = "README.md" readme = "README.md"
@@ -16,7 +16,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"]
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "rustls", "brotli", "flate2-zlib", "secure-cookies", "client"] features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"]
[badges] [badges]
travis-ci = { repository = "actix/actix-web", branch = "master" } travis-ci = { repository = "actix/actix-web", branch = "master" }
@@ -63,35 +63,37 @@ secure-cookies = ["actix-http/secure-cookies"]
fail = ["actix-http/fail"] fail = ["actix-http/fail"]
# openssl # openssl
openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"] ssl = ["openssl", "actix-server/ssl", "awc/ssl"]
# rustls # rustls
# rustls = ["rust-tls", "actix-server/rustls", "awc/rustls"] rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"]
# unix domain sockets support
uds = ["actix-server/uds"]
[dependencies] [dependencies]
actix-codec = "0.2.0-alpha.1" actix-codec = "0.1.2"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
actix-utils = "0.5.0-alpha.1" actix-utils = "0.4.4"
actix-router = "0.1.5" actix-router = "0.1.5"
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.4"
actix-web-codegen = "0.2.0-alpha.1" actix-web-codegen = "0.1.2"
actix-http = "0.3.0-alpha.1" actix-http = "0.2.11"
actix-server = "0.8.0-alpha.1" actix-server = "0.6.1"
actix-server-config = "0.3.0-alpha.1" actix-server-config = "0.1.2"
actix-testing = "0.3.0-alpha.1" actix-testing = "0.1.0"
actix-threadpool = "0.2.0-alpha.1" actix-threadpool = "0.1.1"
awc = { version = "0.3.0-alpha.1", optional = true } awc = { version = "0.2.7", optional = true }
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.3.1" futures = "0.1.25"
hashbrown = "0.6.3" hashbrown = "0.6.3"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
net2 = "0.2.33" net2 = "0.2.33"
parking_lot = "0.9" parking_lot = "0.9"
pin-project = "0.4.5"
regex = "1.0" regex = "1.0"
serde = { version = "1.0", features=["derive"] } serde = { version = "1.0", features=["derive"] }
serde_json = "1.0" serde_json = "1.0"
@@ -100,17 +102,17 @@ time = "0.1.42"
url = "2.1" url = "2.1"
# ssl support # ssl support
open-ssl = { version="0.10", package="openssl", optional = true } openssl = { version="0.10", optional = true }
rust-tls = { version = "0.16", package="rustls", optional = true } rustls = { version = "0.15", optional = true }
[dev-dependencies] [dev-dependencies]
# actix = "0.8.3" actix = "0.8.3"
actix-connect = "0.3.0-alpha.1" actix-connect = "0.2.2"
actix-http-test = "0.3.0-alpha.1" actix-http-test = "0.2.4"
rand = "0.7" rand = "0.7"
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
tokio-timer = "0.3.0-alpha.6" tokio-timer = "0.2.8"
brotli2 = "0.3.2" brotli2 = "0.3.2"
flate2 = "1.0.2" flate2 = "1.0.2"
@@ -124,28 +126,8 @@ actix-web = { path = "." }
actix-http = { path = "actix-http" } actix-http = { path = "actix-http" }
actix-http-test = { path = "test-server" } actix-http-test = { path = "test-server" }
actix-web-codegen = { path = "actix-web-codegen" } actix-web-codegen = { path = "actix-web-codegen" }
# actix-web-actors = { path = "actix-web-actors" } actix-web-actors = { path = "actix-web-actors" }
actix-session = { path = "actix-session" } actix-session = { path = "actix-session" }
actix-files = { path = "actix-files" } actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" } actix-multipart = { path = "actix-multipart" }
awc = { path = "awc" } awc = { path = "awc" }
actix-codec = { git = "https://github.com/actix/actix-net.git" }
actix-connect = { git = "https://github.com/actix/actix-net.git" }
actix-rt = { git = "https://github.com/actix/actix-net.git" }
actix-server = { git = "https://github.com/actix/actix-net.git" }
actix-server-config = { git = "https://github.com/actix/actix-net.git" }
actix-service = { git = "https://github.com/actix/actix-net.git" }
actix-testing = { git = "https://github.com/actix/actix-net.git" }
actix-threadpool = { git = "https://github.com/actix/actix-net.git" }
actix-utils = { git = "https://github.com/actix/actix-net.git" }
# actix-codec = { path = "../actix-net/actix-codec" }
# actix-connect = { path = "../actix-net/actix-connect" }
# actix-rt = { path = "../actix-net/actix-rt" }
# actix-server = { path = "../actix-net/actix-server" }
# actix-server-config = { path = "../actix-net/actix-server-config" }
# actix-service = { path = "../actix-net/actix-service" }
# actix-testing = { path = "../actix-net/actix-testing" }
# actix-threadpool = { path = "../actix-net/actix-threadpool" }
# actix-utils = { path = "../actix-net/actix-utils" }

View File

@@ -1,10 +1,3 @@
## 2.0.0
* Sync handlers has been removed. `.to_async()` methtod has been renamed to `.to()`
replace `fn` with `async fn` to convert sync handler to async
## 1.0.1 ## 1.0.1
* Cors middleware has been moved to `actix-cors` crate * Cors middleware has been moved to `actix-cors` crate

View File

@@ -19,16 +19,17 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust.
* [User Guide](https://actix.rs/docs/) * [User Guide](https://actix.rs/docs/)
* [API Documentation (1.0)](https://docs.rs/actix-web/) * [API Documentation (1.0)](https://docs.rs/actix-web/)
* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/)
* [Chat on gitter](https://gitter.im/actix/actix) * [Chat on gitter](https://gitter.im/actix/actix)
* Cargo package: [actix-web](https://crates.io/crates/actix-web) * Cargo package: [actix-web](https://crates.io/crates/actix-web)
* Minimum supported Rust version: 1.39 or later * Minimum supported Rust version: 1.36 or later
## Example ## Example
```rust ```rust
use actix_web::{web, App, HttpServer, Responder}; use actix_web::{web, App, HttpServer, Responder};
async fn index(info: web::Path<(u32, String)>) -> impl Responder { fn index(info: web::Path<(u32, String)>) -> impl Responder {
format!("Hello {}! id:{}", info.1, info.0) format!("Hello {}! id:{}", info.1, info.0)
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-cors" name = "actix-cors"
version = "0.2.0-alpha.1" version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Cross-origin resource sharing (CORS) for Actix applications." description = "Cross-origin resource sharing (CORS) for Actix applications."
readme = "README.md" readme = "README.md"
@@ -10,14 +10,14 @@ repository = "https://github.com/actix/actix-web.git"
documentation = "https://docs.rs/actix-cors/" documentation = "https://docs.rs/actix-cors/"
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
edition = "2018" edition = "2018"
workspace = ".." #workspace = ".."
[lib] [lib]
name = "actix_cors" name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = "2.0.0-alpha.1" actix-web = "1.0.0"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.0"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.3.1" futures = "0.1.25"

View File

@@ -11,7 +11,7 @@
//! use actix_cors::Cors; //! use actix_cors::Cors;
//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer};
//! //!
//! async fn index(req: HttpRequest) -> &'static str { //! fn index(req: HttpRequest) -> &'static str {
//! "Hello world" //! "Hello world"
//! } //! }
//! //!
@@ -23,8 +23,7 @@
//! .allowed_methods(vec!["GET", "POST"]) //! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE) //! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600) //! .max_age(3600))
//! .finish())
//! .service( //! .service(
//! web::resource("/index.html") //! web::resource("/index.html")
//! .route(web::get().to(index)) //! .route(web::get().to(index))
@@ -42,16 +41,16 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{IntoTransform, Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result}; use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use derive_more::Display; use derive_more::Display;
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, Future, FutureResult};
use futures::Poll;
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Display)] #[derive(Debug, Display)]
@@ -457,9 +456,25 @@ impl Cors {
} }
self self
} }
}
/// Construct cors middleware fn cors<'a>(
pub fn finish(self) -> CorsFactory { parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl<S, B> IntoTransform<CorsFactory, S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn into_transform(self) -> CorsFactory {
let mut slf = if !self.methods { let mut slf = if !self.methods {
self.allowed_methods(vec![ self.allowed_methods(vec![
Method::GET, Method::GET,
@@ -506,16 +521,6 @@ impl Cors {
} }
} }
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
/// `Middleware` for Cross-origin resource sharing support /// `Middleware` for Cross-origin resource sharing support
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
@@ -535,7 +540,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware { ok(CorsMiddleware {
@@ -677,12 +682,12 @@ where
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<
Ready<Result<Self::Response, Error>>, FutureResult<Self::Response, Error>,
LocalBoxFuture<'static, Result<Self::Response, Error>>, Either<S::Future, Box<dyn Future<Item = Self::Response, Error = Error>>>,
>; >;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -693,7 +698,7 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head()))
{ {
return Either::Left(ok(req.error_response(e))); return Either::A(ok(req.error_response(e)));
} }
// allowed headers // allowed headers
@@ -746,50 +751,39 @@ where
.finish() .finish()
.into_body(); .into_body();
Either::Left(ok(req.into_response(res))) Either::A(ok(req.into_response(res)))
} else { } else if req.headers().contains_key(&header::ORIGIN) {
if req.headers().contains_key(&header::ORIGIN) { // Only check requests with a origin header.
// Only check requests with a origin header. if let Err(e) = self.inner.validate_origin(req.head()) {
if let Err(e) = self.inner.validate_origin(req.head()) { return Either::A(ok(req.error_response(e)));
return Either::Left(ok(req.error_response(e)));
}
} }
let inner = self.inner.clone(); let inner = self.inner.clone();
let has_origin = req.headers().contains_key(&header::ORIGIN);
let fut = self.service.call(req);
Either::Right( Either::B(Either::B(Box::new(self.service.call(req).and_then(
async move { move |mut res| {
let res = fut.await; if let Some(origin) =
inner.access_control_allow_origin(res.request().head())
{
res.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
};
if has_origin { if let Some(ref expose) = inner.expose_hdrs {
let mut res = res?; res.headers_mut().insert(
if let Some(origin) = header::ACCESS_CONTROL_EXPOSE_HEADERS,
inner.access_control_allow_origin(res.request().head()) HeaderValue::try_from(expose.as_str()).unwrap(),
{ );
res.headers_mut().insert( }
header::ACCESS_CONTROL_ALLOW_ORIGIN, if inner.supports_credentials {
origin.clone(), res.headers_mut().insert(
); header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
}; HeaderValue::from_static("true"),
);
if let Some(ref expose) = inner.expose_hdrs { }
res.headers_mut().insert( if inner.vary_header {
header::ACCESS_CONTROL_EXPOSE_HEADERS, let value =
HeaderValue::try_from(expose.as_str()).unwrap(), if let Some(hdr) = res.headers_mut().get(&header::VARY) {
);
}
if inner.supports_credentials {
res.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if inner.vary_header {
let value = if let Some(hdr) =
res.headers_mut().get(&header::VARY)
{
let mut val: Vec<u8> = let mut val: Vec<u8> =
Vec::with_capacity(hdr.as_bytes().len() + 8); Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
@@ -798,153 +792,159 @@ where
} else { } else {
HeaderValue::from_static("Origin") HeaderValue::from_static("Origin")
}; };
res.headers_mut().insert(header::VARY, value); res.headers_mut().insert(header::VARY, value);
}
Ok(res)
} else {
res
} }
} Ok(res)
.boxed_local(), },
) ))))
} else {
Either::B(Either::A(self.service.call(req)))
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{service_fn2, Transform}; use actix_service::{IntoService, Transform};
use actix_web::test::{self, block_on, TestRequest}; use actix_web::test::{self, block_on, TestRequest};
use super::*; use super::*;
impl Cors {
fn finish<F, S, B>(self, srv: F) -> CorsMiddleware<S>
where
F: IntoService<S>,
S: Service<
Request = ServiceRequest,
Response = ServiceResponse<B>,
Error = Error,
> + 'static,
S::Future: 'static,
B: 'static,
{
block_on(
IntoTransform::<CorsFactory, S>::into_transform(self)
.new_transform(srv.into_service()),
)
.unwrap()
}
}
#[test] #[test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() { fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new().supports_credentials().send_wildcard().finish(); let _cors = Cors::new()
.supports_credentials()
.send_wildcard()
.finish(test::ok_service());
} }
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
block_on(async { let mut cors = Cors::new().finish(test::ok_service());
let mut cors = Cors::new() let req = TestRequest::with_header("Origin", "https://www.example.com")
.finish() .to_srv_request();
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn default() { fn default() {
block_on(async { let mut cors =
let mut cors = Cors::default() block_on(Cors::default().new_transform(test::ok_service())).unwrap();
.new_transform(test::ok_service()) let req = TestRequest::with_header("Origin", "https://www.example.com")
.await .to_srv_request();
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_preflight() { fn test_preflight() {
block_on(async { let mut cors = Cors::new()
let mut cors = Cors::new() .send_wildcard()
.send_wildcard() .max_age(3600)
.max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE)
.allowed_header(header::CONTENT_TYPE) .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); assert!(cors.inner.validate_allowed_headers(req.head()).is_ok());
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header( .header(
header::ACCESS_CONTROL_REQUEST_HEADERS, header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT", "AUTHORIZATION,ACCEPT",
) )
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
.get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"3600"[..],
resp.headers()
.get(&header::ACCESS_CONTROL_MAX_AGE)
.unwrap()
.as_bytes()
);
let hdr = resp
.headers()
.get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap() .unwrap()
.to_str() .as_bytes()
.unwrap(); );
assert!(hdr.contains("authorization")); assert_eq!(
assert!(hdr.contains("accept")); &b"3600"[..],
assert!(hdr.contains("content-type")); resp.headers()
.get(&header::ACCESS_CONTROL_MAX_AGE)
let methods = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.unwrap() .unwrap()
.to_str() .as_bytes()
.unwrap(); );
assert!(methods.contains("POST")); let hdr = resp
assert!(methods.contains("GET")); .headers()
assert!(methods.contains("OPTIONS")); .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap()
.to_str()
.unwrap();
assert!(hdr.contains("authorization"));
assert!(hdr.contains("accept"));
assert!(hdr.contains("content-type"));
Rc::get_mut(&mut cors.inner).unwrap().preflight = false; let methods = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.unwrap()
.to_str()
.unwrap();
assert!(methods.contains("POST"));
assert!(methods.contains("GET"));
assert!(methods.contains("OPTIONS"));
let req = TestRequest::with_header("Origin", "https://www.example.com") Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await; let req = TestRequest::with_header("Origin", "https://www.example.com")
assert_eq!(resp.status(), StatusCode::OK); .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
}) .header(
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK);
} }
// #[test] // #[test]
@@ -960,254 +960,216 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
block_on(async { let cors = Cors::new()
let cors = Cors::new() .allowed_origin("https://www.example.com")
.allowed_origin("https://www.example.com") .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
cors.inner.validate_origin(req.head()).unwrap(); cors.inner.validate_origin(req.head()).unwrap();
cors.inner.validate_allowed_method(req.head()).unwrap(); cors.inner.validate_allowed_method(req.head()).unwrap();
cors.inner.validate_allowed_headers(req.head()).unwrap(); cors.inner.validate_allowed_headers(req.head()).unwrap();
})
} }
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
block_on(async { let mut cors = Cors::new()
let mut cors = Cors::new() .allowed_origin("https://www.example.com")
.allowed_origin("https://www.example.com") .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_no_origin_response() { fn test_no_origin_response() {
block_on(async { let mut cors = Cors::new().disable_preflight().finish(test::ok_service());
let mut cors = Cors::new()
.disable_preflight()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none());
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(
&b"https://www.example.com"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()); .unwrap()
.as_bytes()
let req = TestRequest::with_header("Origin", "https://www.example.com") );
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://www.example.com"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
})
} }
#[test] #[test]
fn test_response() { fn test_response() {
block_on(async { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let mut cors = Cors::new()
let mut cors = Cors::new() .send_wildcard()
.send_wildcard() .disable_preflight()
.disable_preflight() .max_age(3600)
.max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(exposed_headers.clone())
.allowed_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE)
.allowed_header(header::CONTENT_TYPE) .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
{
let headers = resp
.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.unwrap()
.to_str()
.unwrap()
.split(',')
.map(|s| s.trim())
.collect::<Vec<&str>>();
for h in exposed_headers {
assert!(headers.contains(&h.as_str()));
}
}
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish()
.new_transform(service_fn2(|req: ServiceRequest| {
ok(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
))
}))
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let mut cors = Cors::new()
.disable_vary_header()
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
let origins_str = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.to_str() .as_bytes()
.unwrap(); );
assert_eq!(
&b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
assert_eq!("https://www.example.com", origins_str); {
}) let headers = resp
.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.unwrap()
.to_str()
.unwrap()
.split(',')
.map(|s| s.trim())
.collect::<Vec<&str>>();
for h in exposed_headers {
assert!(headers.contains(&h.as_str()));
}
}
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish(|req: ServiceRequest| {
req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
)
});
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let mut cors = Cors::new()
.disable_vary_header()
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.finish(test::ok_service());
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let origins_str = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.to_str()
.unwrap();
assert_eq!("https://www.example.com", origins_str);
} }
#[test] #[test]
fn test_multiple_origins() { fn test_multiple_origins() {
block_on(async { let mut cors = Cors::new()
let mut cors = Cors::new() .allowed_origin("https://example.com")
.allowed_origin("https://example.com") .allowed_origin("https://example.org")
.allowed_origin("https://example.org") .allowed_methods(vec![Method::GET])
.allowed_methods(vec![Method::GET]) .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::with_header("Origin", "https://example.org")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
#[test] #[test]
fn test_multiple_origins_preflight() { fn test_multiple_origins_preflight() {
block_on(async { let mut cors = Cors::new()
let mut cors = Cors::new() .allowed_origin("https://example.com")
.allowed_origin("https://example.com") .allowed_origin("https://example.org")
.allowed_origin("https://example.org") .allowed_methods(vec![Method::GET])
.allowed_methods(vec![Method::GET]) .finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::with_header("Origin", "https://example.org")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-files" name = "actix-files"
version = "0.2.0-alpha.1" version = "0.1.7"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Static files support for actix web." description = "Static files support for actix web."
readme = "README.md" readme = "README.md"
@@ -18,12 +18,12 @@ name = "actix_files"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "2.0.0-alpha.1", default-features = false } actix-web = { version = "1.0.8", default-features = false }
actix-http = "0.3.0-alpha.1" actix-http = "0.2.11"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
bitflags = "1" bitflags = "1"
bytes = "0.4" bytes = "0.4"
futures = "0.3.1" futures = "0.1.25"
derive_more = "0.15.0" derive_more = "0.15.0"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
@@ -32,4 +32,4 @@ percent-encoding = "2.1"
v_htmlescape = "0.4" v_htmlescape = "0.4"
[dev-dependencies] [dev-dependencies]
actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } actix-web = { version = "1.0.8", features=["ssl"] }

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,6 @@ use actix_web::http::header::{
use actix_web::http::{ContentEncoding, StatusCode}; use actix_web::http::{ContentEncoding, StatusCode};
use actix_web::middleware::BodyEncoding; use actix_web::middleware::BodyEncoding;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder};
use futures::future::{ready, Ready};
use crate::range::HttpRange; use crate::range::HttpRange;
use crate::ChunkedReadFile; use crate::ChunkedReadFile;
@@ -256,8 +255,62 @@ impl NamedFile {
pub(crate) fn last_modified(&self) -> Option<header::HttpDate> { pub(crate) fn last_modified(&self) -> Option<header::HttpDate> {
self.modified.map(|mtime| mtime.into()) self.modified.map(|mtime| mtime.into())
} }
}
pub fn into_response(self, req: &HttpRequest) -> Result<HttpResponse, Error> { impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Result<HttpResponse, Error>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
if self.status_code != StatusCode::OK { if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code); let mut resp = HttpResponse::build(self.status_code);
resp.set(header::ContentType(self.content_type.clone())) resp.set(header::ContentType(self.content_type.clone()))
@@ -389,67 +442,8 @@ impl NamedFile {
counter: 0, counter: 0,
}; };
if offset != 0 || length != self.md.len() { if offset != 0 || length != self.md.len() {
Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)) return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader));
} else { };
Ok(resp.body(SizedStream::new(length, reader))) Ok(resp.body(SizedStream::new(length, reader)))
}
}
}
impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Ready<Result<HttpResponse, Error>>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
ready(self.into_response(req))
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-framed" name = "actix-framed"
version = "0.3.0-alpha.1" version = "0.2.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix framed app server" description = "Actix framed app server"
readme = "README.md" readme = "README.md"
@@ -20,20 +20,19 @@ name = "actix_framed"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-codec = "0.2.0-alpha.1" actix-codec = "0.1.2"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
actix-router = "0.1.2" actix-router = "0.1.2"
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"
actix-http = "0.3.0-alpha.1" actix-http = "0.2.7"
actix-server-config = "0.3.0-alpha.1" actix-server-config = "0.1.2"
bytes = "0.4" bytes = "0.4"
futures = "0.3.1" futures = "0.1.25"
pin-project = "0.4.6"
log = "0.4" log = "0.4"
[dev-dependencies] [dev-dependencies]
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } actix-server = { version = "0.6.0", features=["ssl"] }
actix-connect = { version = "0.3.0-alpha.1", features=["openssl"] } actix-connect = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } actix-http-test = { version = "0.2.4", features=["ssl"] }
actix-utils = "0.5.0-alpha.1" actix-utils = "0.4.4"

View File

@@ -1,24 +1,21 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::h1::{Codec, SendResponse}; use actix_http::h1::{Codec, SendResponse};
use actix_http::{Error, Request, Response}; use actix_http::{Error, Request, Response};
use actix_router::{Path, Router, Url}; use actix_router::{Path, Router, Url};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use futures::future::{ok, FutureExt, LocalBoxFuture}; use futures::{Async, Future, Poll};
use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService};
use crate::request::FramedRequest; use crate::request::FramedRequest;
use crate::state::State; use crate::state::State;
type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>; type BoxedResponse = Box<dyn Future<Item = (), Error = Error>>;
pub trait HttpServiceFactory { pub trait HttpServiceFactory {
type Factory: ServiceFactory; type Factory: NewService;
fn path(&self) -> &str; fn path(&self) -> &str;
@@ -51,19 +48,19 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
pub fn service<U>(mut self, factory: U) -> Self pub fn service<U>(mut self, factory: U) -> Self
where where
U: HttpServiceFactory, U: HttpServiceFactory,
U::Factory: ServiceFactory< U::Factory: NewService<
Config = (), Config = (),
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
> + 'static, > + 'static,
<U::Factory as ServiceFactory>::Future: 'static, <U::Factory as NewService>::Future: 'static,
<U::Factory as ServiceFactory>::Service: Service< <U::Factory as NewService>::Service: Service<
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
>, >,
{ {
let path = factory.path().to_string(); let path = factory.path().to_string();
@@ -73,12 +70,12 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
} }
} }
impl<T, S> IntoServiceFactory<FramedAppFactory<T, S>> for FramedApp<T, S> impl<T, S> IntoNewService<FramedAppFactory<T, S>> for FramedApp<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
S: 'static, S: 'static,
{ {
fn into_factory(self) -> FramedAppFactory<T, S> { fn into_new_service(self) -> FramedAppFactory<T, S> {
FramedAppFactory { FramedAppFactory {
state: self.state, state: self.state,
services: Rc::new(self.services), services: Rc::new(self.services),
@@ -92,9 +89,9 @@ pub struct FramedAppFactory<T, S> {
services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>, services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>,
} }
impl<T, S> ServiceFactory for FramedAppFactory<T, S> impl<T, S> NewService for FramedAppFactory<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
S: 'static, S: 'static,
{ {
type Config = ServerConfig; type Config = ServerConfig;
@@ -131,30 +128,28 @@ pub struct CreateService<T, S> {
enum CreateServiceItem<T, S> { enum CreateServiceItem<T, S> {
Future( Future(
Option<String>, Option<String>,
LocalBoxFuture<'static, Result<BoxedHttpService<FramedRequest<T, S>>, ()>>, Box<dyn Future<Item = BoxedHttpService<FramedRequest<T, S>>, Error = ()>>,
), ),
Service(String, BoxedHttpService<FramedRequest<T, S>>), Service(String, BoxedHttpService<FramedRequest<T, S>>),
} }
impl<S: 'static, T: 'static> Future for CreateService<T, S> impl<S: 'static, T: 'static> Future for CreateService<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Output = Result<FramedAppService<T, S>, ()>; type Item = FramedAppService<T, S>;
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut done = true; let mut done = true;
// poll http services // poll http services
for item in &mut self.fut { for item in &mut self.fut {
let res = match item { let res = match item {
CreateServiceItem::Future(ref mut path, ref mut fut) => { CreateServiceItem::Future(ref mut path, ref mut fut) => {
match Pin::new(fut).poll(cx) { match fut.poll()? {
Poll::Ready(Ok(service)) => { Async::Ready(service) => Some((path.take().unwrap(), service)),
Some((path.take().unwrap(), service)) Async::NotReady => {
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
done = false; done = false;
None None
} }
@@ -181,12 +176,12 @@ where
} }
router router
}); });
Poll::Ready(Ok(FramedAppService { Ok(Async::Ready(FramedAppService {
router: router.finish(), router: router.finish(),
state: self.state.clone(), state: self.state.clone(),
})) }))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -198,15 +193,15 @@ pub struct FramedAppService<T, S> {
impl<S: 'static, T: 'static> Service for FramedAppService<T, S> impl<S: 'static, T: 'static> Service for FramedAppService<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = BoxedResponse; type Future = BoxedResponse;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
@@ -215,8 +210,8 @@ where
if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { if let Some((srv, _info)) = self.router.recognize_mut(&mut path) {
return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); return srv.call(FramedRequest::new(req, framed, path, self.state.clone()));
} }
SendResponse::new(framed, Response::NotFound().finish()) Box::new(
.then(|_| ok(())) SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())),
.boxed_local() )
} }
} }

View File

@@ -1,38 +1,36 @@
use std::task::{Context, Poll};
use actix_http::Error; use actix_http::Error;
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{FutureExt, LocalBoxFuture}; use futures::{Future, Poll};
pub(crate) type BoxedHttpService<Req> = Box< pub(crate) type BoxedHttpService<Req> = Box<
dyn Service< dyn Service<
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
>, >,
>; >;
pub(crate) type BoxedHttpNewService<Req> = Box< pub(crate) type BoxedHttpNewService<Req> = Box<
dyn ServiceFactory< dyn NewService<
Config = (), Config = (),
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
Service = BoxedHttpService<Req>, Service = BoxedHttpService<Req>,
Future = LocalBoxFuture<'static, Result<BoxedHttpService<Req>, ()>>, Future = Box<dyn Future<Item = BoxedHttpService<Req>, Error = ()>>,
>, >,
>; >;
pub(crate) struct HttpNewService<T: ServiceFactory>(T); pub(crate) struct HttpNewService<T: NewService>(T);
impl<T> HttpNewService<T> impl<T> HttpNewService<T>
where where
T: ServiceFactory<Response = (), Error = Error>, T: NewService<Response = (), Error = Error>,
T::Response: 'static, T::Response: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static, T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
pub fn new(service: T) -> Self { pub fn new(service: T) -> Self {
@@ -40,12 +38,12 @@ where
} }
} }
impl<T> ServiceFactory for HttpNewService<T> impl<T> NewService for HttpNewService<T>
where where
T: ServiceFactory<Config = (), Response = (), Error = Error>, T: NewService<Config = (), Response = (), Error = Error>,
T::Request: 'static, T::Request: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static, T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
type Config = (); type Config = ();
@@ -54,19 +52,13 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = BoxedHttpService<T::Request>; type Service = BoxedHttpService<T::Request>;
type Future = LocalBoxFuture<'static, Result<Self::Service, ()>>; type Future = Box<dyn Future<Item = Self::Service, Error = ()>>;
fn new_service(&self, _: &()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
let fut = self.0.new_service(&()); Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| {
let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service });
async move { Ok(service)
fut.await.map_err(|_| ()).map(|service| { }))
let service: BoxedHttpService<_> =
Box::new(HttpServiceWrapper { service });
service
})
}
.boxed_local()
} }
} }
@@ -78,7 +70,7 @@ impl<T> Service for HttpServiceWrapper<T>
where where
T: Service< T: Service<
Response = (), Response = (),
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
Error = Error, Error = Error,
>, >,
T::Request: 'static, T::Request: 'static,
@@ -86,10 +78,10 @@ where
type Request = T::Request; type Request = T::Request;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<(), Error>>; type Future = Box<dyn Future<Item = (), Error = Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {

View File

@@ -1,12 +1,11 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{http::Method, Error}; use actix_http::{http::Method, Error};
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, FutureResult};
use futures::{Async, Future, IntoFuture, Poll};
use log::error; use log::error;
use crate::app::HttpServiceFactory; use crate::app::HttpServiceFactory;
@@ -16,11 +15,11 @@ use crate::request::FramedRequest;
/// ///
/// Route uses builder-like pattern for configuration. /// Route uses builder-like pattern for configuration.
/// If handler is not explicitly set, default *404 Not Found* handler is used. /// If handler is not explicitly set, default *404 Not Found* handler is used.
pub struct FramedRoute<Io, S, F = (), R = (), E = ()> { pub struct FramedRoute<Io, S, F = (), R = ()> {
handler: F, handler: F,
pattern: String, pattern: String,
methods: Vec<Method>, methods: Vec<Method>,
state: PhantomData<(Io, S, R, E)>, state: PhantomData<(Io, S, R)>,
} }
impl<Io, S> FramedRoute<Io, S> { impl<Io, S> FramedRoute<Io, S> {
@@ -54,12 +53,12 @@ impl<Io, S> FramedRoute<Io, S> {
self self
} }
pub fn to<F, R, E>(self, handler: F) -> FramedRoute<Io, S, F, R, E> pub fn to<F, R>(self, handler: F) -> FramedRoute<Io, S, F, R>
where where
F: FnMut(FramedRequest<Io, S>) -> R, F: FnMut(FramedRequest<Io, S>) -> R,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
R::Future: 'static,
E: fmt::Debug, R::Error: fmt::Debug,
{ {
FramedRoute { FramedRoute {
handler, handler,
@@ -70,14 +69,15 @@ impl<Io, S> FramedRoute<Io, S> {
} }
} }
impl<Io, S, F, R, E> HttpServiceFactory for FramedRoute<Io, S, F, R, E> impl<Io, S, F, R> HttpServiceFactory for FramedRoute<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Factory = FramedRouteFactory<Io, S, F, R, E>; type Factory = FramedRouteFactory<Io, S, F, R>;
fn path(&self) -> &str { fn path(&self) -> &str {
&self.pattern &self.pattern
@@ -92,26 +92,27 @@ where
} }
} }
pub struct FramedRouteFactory<Io, S, F, R, E> { pub struct FramedRouteFactory<Io, S, F, R> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R, E)>, _t: PhantomData<(Io, S, R)>,
} }
impl<Io, S, F, R, E> ServiceFactory for FramedRouteFactory<Io, S, F, R, E> impl<Io, S, F, R> NewService for FramedRouteFactory<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Config = (); type Config = ();
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = FramedRouteService<Io, S, F, R, E>; type Service = FramedRouteService<Io, S, F, R>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
ok(FramedRouteService { ok(FramedRouteService {
@@ -122,38 +123,35 @@ where
} }
} }
pub struct FramedRouteService<Io, S, F, R, E> { pub struct FramedRouteService<Io, S, F, R> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R, E)>, _t: PhantomData<(Io, S, R)>,
} }
impl<Io, S, F, R, E> Service for FramedRouteService<Io, S, F, R, E> impl<Io, S, F, R> Service for FramedRouteService<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<(), Error>>; type Future = Box<dyn Future<Item = (), Error = Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future { fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future {
let fut = (self.handler)(req); Box::new((self.handler)(req).into_future().then(|res| {
async move {
let res = fut.await;
if let Err(e) = res { if let Err(e) = res {
error!("Error in request handler: {}", e); error!("Error in request handler: {}", e);
} }
Ok(()) Ok(())
} }))
.boxed_local()
} }
} }

View File

@@ -1,6 +1,4 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::body::BodySize; use actix_http::body::BodySize;
@@ -8,9 +6,9 @@ use actix_http::error::ResponseError;
use actix_http::h1::{Codec, Message}; use actix_http::h1::{Codec, Message};
use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{Request, Response}; use actix_http::{Request, Response};
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{err, ok, Either, Ready}; use futures::future::{ok, Either, FutureResult};
use futures::Future; use futures::{Async, Future, IntoFuture, Poll, Sink};
/// Service that verifies incoming request if it is valid websocket /// Service that verifies incoming request if it is valid websocket
/// upgrade request. In case of error returns `HandshakeError` /// upgrade request. In case of error returns `HandshakeError`
@@ -24,14 +22,14 @@ impl<T, C> Default for VerifyWebSockets<T, C> {
} }
} }
impl<T, C> ServiceFactory for VerifyWebSockets<T, C> { impl<T, C> NewService for VerifyWebSockets<T, C> {
type Config = C; type Config = C;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = VerifyWebSockets<T, C>; type Service = VerifyWebSockets<T, C>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData }) ok(VerifyWebSockets { _t: PhantomData })
@@ -42,16 +40,16 @@ impl<T, C> Service for VerifyWebSockets<T, C> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(req.head()) { match verify_handshake(req.head()) {
Err(e) => err((e, framed)), Err(e) => Err((e, framed)).into_future(),
Ok(_) => ok((req, framed)), Ok(_) => Ok((req, framed)).into_future(),
} }
} }
} }
@@ -69,9 +67,9 @@ where
} }
} }
impl<T, R, E, C> ServiceFactory for SendError<T, R, E, C> impl<T, R, E, C> NewService for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
@@ -81,7 +79,7 @@ where
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = SendError<T, R, E, C>; type Service = SendError<T, R, E, C>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(SendError(PhantomData)) ok(SendError(PhantomData))
@@ -90,25 +88,25 @@ where
impl<T, R, E, C> Service for SendError<T, R, E, C> impl<T, R, E, C> Service for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
type Request = Result<R, (E, Framed<T, Codec>)>; type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R; type Response = R;
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type Future = Either<Ready<Result<R, (E, Framed<T, Codec>)>>, SendErrorFut<T, R, E>>; type Future = Either<FutureResult<R, (E, Framed<T, Codec>)>, SendErrorFut<T, R, E>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future { fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future {
match req { match req {
Ok(r) => Either::Left(ok(r)), Ok(r) => Either::A(ok(r)),
Err((e, framed)) => { Err((e, framed)) => {
let res = e.error_response().drop_body(); let res = e.error_response().drop_body();
Either::Right(SendErrorFut { Either::B(SendErrorFut {
framed: Some(framed), framed: Some(framed),
res: Some((res, BodySize::Empty).into()), res: Some((res, BodySize::Empty).into()),
err: Some(e), err: Some(e),
@@ -119,7 +117,6 @@ where
} }
} }
#[pin_project::pin_project]
pub struct SendErrorFut<T, R, E> { pub struct SendErrorFut<T, R, E> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
@@ -130,27 +127,23 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E> impl<T, R, E> Future for SendErrorFut<T, R, E>
where where
E: ResponseError, E: ResponseError,
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Output = Result<R, (E, Framed<T, Codec>)>; type Item = R;
type Error = (E, Framed<T, Codec>);
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() { if let Some(res) = self.res.take() {
if self.framed.as_mut().unwrap().write(res).is_err() { if self.framed.as_mut().unwrap().force_send(res).is_err() {
return Poll::Ready(Err(( return Err((self.err.take().unwrap(), self.framed.take().unwrap()));
self.err.take().unwrap(),
self.framed.take().unwrap(),
)));
} }
} }
match self.framed.as_mut().unwrap().flush(cx) { match self.framed.as_mut().unwrap().poll_complete() {
Poll::Ready(Ok(_)) => { Ok(Async::Ready(_)) => {
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) Err((self.err.take().unwrap(), self.framed.take().unwrap()))
} }
Poll::Ready(Err(_)) => { Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())),
}
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,6 +1,4 @@
//! Various helpers for Actix applications to use during testing. //! Various helpers for Actix applications to use during testing.
use std::future::Future;
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::h1::Codec; use actix_http::h1::Codec;
use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue};
@@ -8,6 +6,7 @@ use actix_http::http::{HttpTryFrom, Method, Uri, Version};
use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest};
use actix_router::{Path, Url}; use actix_router::{Path, Url};
use actix_rt::Runtime; use actix_rt::Runtime;
use futures::IntoFuture;
use crate::{FramedRequest, State}; use crate::{FramedRequest, State};
@@ -122,10 +121,10 @@ impl<S> TestRequest<S> {
pub fn run<F, R, I, E>(self, f: F) -> Result<I, E> pub fn run<F, R, I, E>(self, f: F) -> Result<I, E>
where where
F: FnOnce(FramedRequest<TestBuffer, S>) -> R, F: FnOnce(FramedRequest<TestBuffer, S>) -> R,
R: Future<Output = Result<I, E>>, R: IntoFuture<Item = I, Error = E>,
{ {
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
rt.block_on(f(self.finish())) rt.block_on(f(self.finish()).into_future())
} }
} }

View File

@@ -1,31 +1,30 @@
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response};
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory}; use actix_service::{IntoNewService, NewService};
use actix_utils::framed::FramedTransport; use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{future, SinkExt, StreamExt}; use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets};
async fn ws_service<T: AsyncRead + AsyncWrite>( fn ws_service<T: AsyncRead + AsyncWrite>(
req: FramedRequest<T>, req: FramedRequest<T>,
) -> Result<(), Error> { ) -> impl Future<Item = (), Error = Error> {
let (req, mut framed, _) = req.into_parts(); let (req, framed, _) = req.into_parts();
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.await .map_err(|_| panic!())
.unwrap(); .and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service) FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.await .map_err(|_| panic!())
.unwrap(); })
Ok(())
} }
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> { fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
@@ -35,129 +34,108 @@ async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
Ok(msg) ok(msg)
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
block_on(async { let mut srv = TestServer::new(|| {
let mut srv = TestServer::start(|| { HttpService::build()
HttpService::build() .upgrade(
.upgrade( FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)),
FramedApp::new() )
.service(FramedRoute::get("/index.html").to(ws_service)), .finish(|_| future::ok::<_, Error>(Response::NotFound()))
) });
.finish(|_| future::ok::<_, Error>(Response::NotFound()))
});
assert!(srv.ws_at("/test").await.is_err()); assert!(srv.ws_at("/test").is_err());
// client service // client service
let mut framed = srv.ws_at("/index.html").await.unwrap(); let framed = srv.ws_at("/index.html").unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Binary("text".into())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(
assert_eq!( item,
item.unwrap().unwrap(), Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) );
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Ping("text".into())))
assert_eq!( .unwrap();
item.unwrap().unwrap(), let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
ws::Frame::Pong("text".to_string().into()) assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
);
framed let framed = srv
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.await .unwrap();
.unwrap();
let (item, _) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
})
} }
#[test] #[test]
fn test_service() { fn test_service() {
block_on(async { let mut srv = TestServer::new(|| {
let mut srv = TestServer::start(|| { actix_http::h1::OneRequest::new().map_err(|_| ()).and_then(
pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then( VerifyWebSockets::default()
pipeline_factory( .then(SendError::default())
pipeline_factory(VerifyWebSockets::default()) .map_err(|_| ())
.then(SendError::default())
.map_err(|_| ()),
)
.and_then( .and_then(
FramedApp::new() FramedApp::new()
.service(FramedRoute::get("/index.html").to(ws_service)) .service(FramedRoute::get("/index.html").to(ws_service))
.into_factory() .into_new_service()
.map_err(|_| ()), .map_err(|_| ()),
), ),
) )
}); });
// non ws request // non ws request
let res = srv.get("/index.html").send().await.unwrap(); let res = srv.block_on(srv.get("/index.html").send()).unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.status(), StatusCode::BAD_REQUEST);
// not found // not found
assert!(srv.ws_at("/test").await.is_err()); assert!(srv.ws_at("/test").is_err());
// client service // client service
let mut framed = srv.ws_at("/index.html").await.unwrap(); let framed = srv.ws_at("/index.html").unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Binary("text".into())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(
assert_eq!( item,
item.unwrap().unwrap(), Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) );
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Ping("text".into())))
assert_eq!( .unwrap();
item.unwrap().unwrap(), let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
ws::Frame::Pong("text".to_string().into()) assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
);
framed let framed = srv
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.await .unwrap();
.unwrap();
let (item, _) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
})
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-http" name = "actix-http"
version = "0.3.0-alpha.1" version = "0.2.11"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http primitives" description = "Actix http primitives"
readme = "README.md" readme = "README.md"
@@ -16,7 +16,7 @@ edition = "2018"
workspace = ".." workspace = ".."
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"]
[lib] [lib]
name = "actix_http" name = "actix_http"
@@ -26,10 +26,10 @@ path = "src/lib.rs"
default = [] default = []
# openssl # openssl
openssl = ["open-ssl", "actix-connect/openssl", "tokio-openssl"] ssl = ["openssl", "actix-connect/ssl"]
# rustls support # rustls support
# rustls = ["rust-tls", "webpki-roots", "actix-connect/rustls"] rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"]
# brotli encoding, requires c compiler # brotli encoding, requires c compiler
brotli = ["brotli2"] brotli = ["brotli2"]
@@ -47,24 +47,23 @@ fail = ["failure"]
secure-cookies = ["ring"] secure-cookies = ["ring"]
[dependencies] [dependencies]
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
actix-codec = "0.2.0-alpha.1" actix-codec = "0.1.2"
actix-connect = "1.0.0-alpha.1" actix-connect = "0.2.4"
actix-utils = "0.5.0-alpha.1" actix-utils = "0.4.4"
actix-server-config = "0.3.0-alpha.1" actix-server-config = "0.1.2"
actix-threadpool = "0.2.0-alpha.1" actix-threadpool = "0.1.1"
base64 = "0.10" base64 = "0.10"
bitflags = "1.0" bitflags = "1.0"
bytes = "0.4" bytes = "0.4"
copyless = "0.1.4" copyless = "0.1.4"
chrono = "0.4.6"
derive_more = "0.15.0" derive_more = "0.15.0"
either = "1.5.2" either = "1.5.2"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.3.1" futures = "0.1.25"
hashbrown = "0.6.3" hashbrown = "0.6.3"
h2 = "0.2.0-alpha.3" h2 = "0.1.16"
http = "0.1.17" http = "0.1.17"
httparse = "1.3" httparse = "1.3"
indexmap = "1.2" indexmap = "1.2"
@@ -73,7 +72,6 @@ language-tags = "0.2"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
pin-project = "0.4.5"
rand = "0.7" rand = "0.7"
regex = "1.0" regex = "1.0"
serde = "1.0" serde = "1.0"
@@ -82,16 +80,13 @@ sha1 = "0.6"
slab = "0.4" slab = "0.4"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
time = "0.1.42" time = "0.1.42"
tokio-tcp = "0.1.3"
tokio = "=0.2.0-alpha.6" tokio-timer = "0.2.8"
tokio-io = "=0.2.0-alpha.6" tokio-current-thread = "0.1"
tokio-net = "=0.2.0-alpha.6" trust-dns-resolver = { version="0.11.1", default-features = false }
tokio-timer = "0.3.0-alpha.6"
tokio-executor = "=0.2.0-alpha.6"
trust-dns-resolver = { version="0.18.0-alpha.1", default-features = false }
# for secure cookie # for secure cookie
ring = { version = "0.16.9", optional = true } ring = { version = "0.14.6", optional = true }
# compression # compression
brotli2 = { version="0.3.2", optional = true } brotli2 = { version="0.3.2", optional = true }
@@ -99,17 +94,17 @@ flate2 = { version="1.0.7", optional = true, default-features = false }
# optional deps # optional deps
failure = { version = "0.1.5", optional = true } failure = { version = "0.1.5", optional = true }
open-ssl = { version="0.10", package="openssl", optional = true } openssl = { version="0.10", optional = true }
tokio-openssl = { version = "0.4.0-alpha.6", optional = true } rustls = { version = "0.15.2", optional = true }
webpki-roots = { version = "0.16", optional = true }
rust-tls = { version = "0.16.0", package="rustls", optional = true } chrono = "0.4.6"
webpki-roots = { version = "0.18", optional = true }
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] }
actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } actix-connect = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } actix-http-test = { version = "0.2.4", features=["ssl"] }
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
open-ssl = { version="0.10", package="openssl" } openssl = { version="0.10" }
tokio-tcp = "0.1"

View File

@@ -1,9 +1,9 @@
use std::{env, io}; use std::{env, io};
use actix_http::{Error, HttpService, Request, Response}; use actix_http::{error::PayloadError, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::StreamExt; use futures::{Future, Stream};
use http::header::HeaderValue; use http::header::HeaderValue;
use log::info; use log::info;
@@ -17,22 +17,20 @@ fn main() -> io::Result<()> {
.client_timeout(1000) .client_timeout(1000)
.client_disconnect(1000) .client_disconnect(1000)
.finish(|mut req: Request| { .finish(|mut req: Request| {
async move { req.take_payload()
let mut body = BytesMut::new(); .fold(BytesMut::new(), move |mut body, chunk| {
while let Some(item) = req.payload().next().await { body.extend_from_slice(&chunk);
body.extend_from_slice(&item?); Ok::<_, PayloadError>(body)
} })
.and_then(|bytes| {
info!("request body: {:?}", body); info!("request body: {:?}", bytes);
Ok::<_, Error>( let mut res = Response::Ok();
Response::Ok() res.header(
.header( "x-head",
"x-head", HeaderValue::from_static("dummy value!"),
HeaderValue::from_static("dummy value!"), );
) Ok(res.body(bytes))
.body(body), })
)
}
}) })
})? })?
.run() .run()

View File

@@ -1,22 +1,25 @@
use std::{env, io}; use std::{env, io};
use actix_http::http::HeaderValue; use actix_http::http::HeaderValue;
use actix_http::{Error, HttpService, Request, Response}; use actix_http::{error::PayloadError, Error, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::StreamExt; use futures::{Future, Stream};
use log::info; use log::info;
async fn handle_request(mut req: Request) -> Result<Response, Error> { fn handle_request(mut req: Request) -> impl Future<Item = Response, Error = Error> {
let mut body = BytesMut::new(); req.take_payload()
while let Some(item) = req.payload().next().await { .fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&item?) body.extend_from_slice(&chunk);
} Ok::<_, PayloadError>(body)
})
info!("request body: {:?}", body); .from_err()
Ok(Response::Ok() .and_then(|bytes| {
.header("x-head", HeaderValue::from_static("dummy value!")) info!("request body: {:?}", bytes);
.body(body)) let mut res = Response::Ok();
res.header("x-head", HeaderValue::from_static("dummy value!"));
Ok(res.body(bytes))
})
} }
fn main() -> io::Result<()> { fn main() -> io::Result<()> {
@@ -25,7 +28,7 @@ fn main() -> io::Result<()> {
Server::build() Server::build()
.bind("echo", "127.0.0.1:8080", || { .bind("echo", "127.0.0.1:8080", || {
HttpService::build().finish(handle_request) HttpService::build().finish(|_req: Request| handle_request(_req))
})? })?
.run() .run()
} }

View File

@@ -1,11 +1,8 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem}; use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::Stream; use futures::{Async, Poll, Stream};
use pin_project::{pin_project, project};
use crate::error::Error; use crate::error::Error;
@@ -35,7 +32,7 @@ impl BodySize {
pub trait MessageBody { pub trait MessageBody {
fn size(&self) -> BodySize; fn size(&self) -> BodySize;
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>>; fn poll_next(&mut self) -> Poll<Option<Bytes>, Error>;
} }
impl MessageBody for () { impl MessageBody for () {
@@ -43,8 +40,8 @@ impl MessageBody for () {
BodySize::Empty BodySize::Empty
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
Poll::Ready(None) Ok(Async::Ready(None))
} }
} }
@@ -53,12 +50,11 @@ impl<T: MessageBody> MessageBody for Box<T> {
self.as_ref().size() self.as_ref().size()
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
self.as_mut().poll_next(cx) self.as_mut().poll_next()
} }
} }
#[pin_project]
pub enum ResponseBody<B> { pub enum ResponseBody<B> {
Body(B), Body(B),
Other(Body), Other(Body),
@@ -97,24 +93,20 @@ impl<B: MessageBody> MessageBody for ResponseBody<B> {
} }
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
match self { match self {
ResponseBody::Body(ref mut body) => body.poll_next(cx), ResponseBody::Body(ref mut body) => body.poll_next(),
ResponseBody::Other(ref mut body) => body.poll_next(cx), ResponseBody::Other(ref mut body) => body.poll_next(),
} }
} }
} }
impl<B: MessageBody> Stream for ResponseBody<B> { impl<B: MessageBody> Stream for ResponseBody<B> {
type Item = Result<Bytes, Error>; type Item = Bytes;
type Error = Error;
#[project] fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { self.poll_next()
#[project]
match self.project() {
ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(cx),
}
} }
} }
@@ -152,19 +144,19 @@ impl MessageBody for Body {
} }
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
match self { match self {
Body::None => Poll::Ready(None), Body::None => Ok(Async::Ready(None)),
Body::Empty => Poll::Ready(None), Body::Empty => Ok(Async::Ready(None)),
Body::Bytes(ref mut bin) => { Body::Bytes(ref mut bin) => {
let len = bin.len(); let len = bin.len();
if len == 0 { if len == 0 {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new())))) Ok(Async::Ready(Some(mem::replace(bin, Bytes::new()))))
} }
} }
Body::Message(ref mut body) => body.poll_next(cx), Body::Message(ref mut body) => body.poll_next(),
} }
} }
} }
@@ -250,7 +242,7 @@ impl From<serde_json::Value> for Body {
impl<S> From<SizedStream<S>> for Body impl<S> From<SizedStream<S>> for Body
where where
S: Stream<Item = Result<Bytes, Error>> + 'static, S: Stream<Item = Bytes, Error = Error> + 'static,
{ {
fn from(s: SizedStream<S>) -> Body { fn from(s: SizedStream<S>) -> Body {
Body::from_message(s) Body::from_message(s)
@@ -259,7 +251,7 @@ where
impl<S, E> From<BodyStream<S, E>> for Body impl<S, E> From<BodyStream<S, E>> for Body
where where
S: Stream<Item = Result<Bytes, E>> + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
fn from(s: BodyStream<S, E>) -> Body { fn from(s: BodyStream<S, E>) -> Body {
@@ -272,11 +264,11 @@ impl MessageBody for Bytes {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(self, Bytes::new())))) Ok(Async::Ready(Some(mem::replace(self, Bytes::new()))))
} }
} }
} }
@@ -286,11 +278,13 @@ impl MessageBody for BytesMut {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze()))) Ok(Async::Ready(Some(
mem::replace(self, BytesMut::new()).freeze(),
)))
} }
} }
} }
@@ -300,11 +294,11 @@ impl MessageBody for &'static str {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from_static( Ok(Async::Ready(Some(Bytes::from_static(
mem::replace(self, "").as_ref(), mem::replace(self, "").as_ref(),
)))) ))))
} }
@@ -316,11 +310,13 @@ impl MessageBody for &'static [u8] {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b""))))) Ok(Async::Ready(Some(Bytes::from_static(mem::replace(
self, b"",
)))))
} }
} }
} }
@@ -330,11 +326,14 @@ impl MessageBody for Vec<u8> {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new()))))) Ok(Async::Ready(Some(Bytes::from(mem::replace(
self,
Vec::new(),
)))))
} }
} }
} }
@@ -344,11 +343,11 @@ impl MessageBody for String {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from( Ok(Async::Ready(Some(Bytes::from(
mem::replace(self, String::new()).into_bytes(), mem::replace(self, String::new()).into_bytes(),
)))) ))))
} }
@@ -357,16 +356,14 @@ impl MessageBody for String {
/// Type represent streaming body. /// Type represent streaming body.
/// Response does not contain `content-length` header and appropriate transfer encoding is used. /// Response does not contain `content-length` header and appropriate transfer encoding is used.
#[pin_project]
pub struct BodyStream<S, E> { pub struct BodyStream<S, E> {
#[pin]
stream: S, stream: S,
_t: PhantomData<E>, _t: PhantomData<E>,
} }
impl<S, E> BodyStream<S, E> impl<S, E> BodyStream<S, E>
where where
S: Stream<Item = Result<Bytes, E>>, S: Stream<Item = Bytes, Error = E>,
E: Into<Error>, E: Into<Error>,
{ {
pub fn new(stream: S) -> Self { pub fn new(stream: S) -> Self {
@@ -379,34 +376,28 @@ where
impl<S, E> MessageBody for BodyStream<S, E> impl<S, E> MessageBody for BodyStream<S, E>
where where
S: Stream<Item = Result<Bytes, E>>, S: Stream<Item = Bytes, Error = E>,
E: Into<Error>, E: Into<Error>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Stream BodySize::Stream
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
unsafe { Pin::new_unchecked(self) } self.stream.poll().map_err(std::convert::Into::into)
.project()
.stream
.poll_next(cx)
.map(|res| res.map(|res| res.map_err(std::convert::Into::into)))
} }
} }
/// Type represent streaming body. This body implementation should be used /// Type represent streaming body. This body implementation should be used
/// if total size of stream is known. Data get sent as is without using transfer encoding. /// if total size of stream is known. Data get sent as is without using transfer encoding.
#[pin_project]
pub struct SizedStream<S> { pub struct SizedStream<S> {
size: u64, size: u64,
#[pin]
stream: S, stream: S,
} }
impl<S> SizedStream<S> impl<S> SizedStream<S>
where where
S: Stream<Item = Result<Bytes, Error>>, S: Stream<Item = Bytes, Error = Error>,
{ {
pub fn new(size: u64, stream: S) -> Self { pub fn new(size: u64, stream: S) -> Self {
SizedStream { size, stream } SizedStream { size, stream }
@@ -415,25 +406,20 @@ where
impl<S> MessageBody for SizedStream<S> impl<S> MessageBody for SizedStream<S>
where where
S: Stream<Item = Result<Bytes, Error>>, S: Stream<Item = Bytes, Error = Error>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Sized64(self.size) BodySize::Sized64(self.size)
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
unsafe { Pin::new_unchecked(self) } self.stream.poll()
.project()
.stream
.poll_next(cx)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use actix_http_test::block_on;
use futures::future::{lazy, poll_fn};
impl Body { impl Body {
pub(crate) fn get_ref(&self) -> &[u8] { pub(crate) fn get_ref(&self) -> &[u8] {
@@ -461,8 +447,8 @@ mod tests {
assert_eq!("test".size(), BodySize::Sized(4)); assert_eq!("test".size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| "test".poll_next(cx))).unwrap().ok(), "test".poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
@@ -478,10 +464,8 @@ mod tests {
assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); assert_eq!((&b"test"[..]).size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| (&b"test"[..]).poll_next(cx))) (&b"test"[..]).poll_next().unwrap(),
.unwrap() Async::Ready(Some(Bytes::from("test")))
.ok(),
Some(Bytes::from("test"))
); );
} }
@@ -492,10 +476,8 @@ mod tests {
assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| Vec::from("test").poll_next(cx))) Vec::from("test").poll_next().unwrap(),
.unwrap() Async::Ready(Some(Bytes::from("test")))
.ok(),
Some(Bytes::from("test"))
); );
} }
@@ -507,8 +489,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
@@ -520,8 +502,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
@@ -535,22 +517,22 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
#[test] #[test]
fn test_unit() { fn test_unit() {
assert_eq!(().size(), BodySize::Empty); assert_eq!(().size(), BodySize::Empty);
assert!(block_on(poll_fn(|cx| ().poll_next(cx))).is_none()); assert_eq!(().poll_next().unwrap(), Async::Ready(None));
} }
#[test] #[test]
fn test_box() { fn test_box() {
let mut val = Box::new(()); let mut val = Box::new(());
assert_eq!(val.size(), BodySize::Empty); assert_eq!(val.size(), BodySize::Empty);
assert!(block_on(poll_fn(|cx| val.poll_next(cx))).is_none()); assert_eq!(val.poll_next().unwrap(), Async::Ready(None));
} }
#[test] #[test]

View File

@@ -4,7 +4,7 @@ use std::rc::Rc;
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::ServerConfig as SrvConfig; use actix_server_config::ServerConfig as SrvConfig;
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::{KeepAlive, ServiceConfig}; use crate::config::{KeepAlive, ServiceConfig};
@@ -32,10 +32,9 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static,
{ {
/// Create instance of `ServiceConfigBuilder` /// Create instance of `ServiceConfigBuilder`
pub fn new() -> Self { pub fn new() -> Self {
@@ -53,22 +52,19 @@ where
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U> impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<
U: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
>, >,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
/// Set server keep-alive setting. /// Set server keep-alive setting.
/// ///
@@ -112,17 +108,16 @@ where
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U> pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
where where
F: IntoServiceFactory<X1>, F: IntoNewService<X1>,
X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>, X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
expect: expect.into_factory(), expect: expect.into_new_service(),
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
@@ -135,22 +130,21 @@ where
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1> pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where where
F: IntoServiceFactory<U1>, F: IntoNewService<U1>,
U1: ServiceFactory< U1: NewService<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
expect: self.expect, expect: self.expect,
upgrade: Some(upgrade.into_factory()), upgrade: Some(upgrade.into_new_service()),
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
} }
@@ -172,8 +166,8 @@ where
/// Finish service configuration and create *http service* for HTTP/1 protocol. /// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U> pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where where
B: MessageBody, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -183,7 +177,7 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
H1Service::with_config(cfg, service.into_factory()) H1Service::with_config(cfg, service.into_new_service())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)
@@ -193,10 +187,10 @@ where
pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B> pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
@@ -204,17 +198,18 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) H2Service::with_config(cfg, service.into_new_service())
.on_connect(self.on_connect)
} }
/// Finish service configuration and create `HttpService` instance. /// Finish service configuration and create `HttpService` instance.
pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U> pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
@@ -222,7 +217,7 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
HttpService::with_config(cfg, service.into_factory()) HttpService::with_config(cfg, service.into_new_service())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)

View File

@@ -1,12 +1,10 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, time}; use std::{fmt, io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; use futures::future::{err, Either, Future, FutureResult};
use futures::Poll;
use h2::client::SendRequest; use h2::client::SendRequest;
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::h1::ClientCodec; use crate::h1::ClientCodec;
@@ -23,8 +21,8 @@ pub(crate) enum ConnectionType<Io> {
} }
pub trait Connection { pub trait Connection {
type Io: AsyncRead + AsyncWrite + Unpin; type Io: AsyncRead + AsyncWrite;
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>; type Future: Future<Item = (ResponseHead, Payload), Error = SendRequestError>;
fn protocol(&self) -> Protocol; fn protocol(&self) -> Protocol;
@@ -36,7 +34,8 @@ pub trait Connection {
) -> Self::Future; ) -> Self::Future;
type TunnelFuture: Future< type TunnelFuture: Future<
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
@@ -72,7 +71,7 @@ where
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> { impl<T: AsyncRead + AsyncWrite> IoConnection<T> {
pub(crate) fn new( pub(crate) fn new(
io: ConnectionType<T>, io: ConnectionType<T>,
created: time::Instant, created: time::Instant,
@@ -92,11 +91,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
impl<T> Connection for IoConnection<T> impl<T> Connection for IoConnection<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
type Io = T; type Io = T;
type Future = type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self.io { match self.io {
@@ -112,30 +111,38 @@ where
body: B, body: B,
) -> Self::Future { ) -> Self::Future {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => Box::new(h1proto::send_request(
h1proto::send_request(io, head.into(), body, self.created, self.pool) io,
.boxed_local() head.into(),
} body,
ConnectionType::H2(io) => { self.created,
h2proto::send_request(io, head.into(), body, self.created, self.pool) self.pool,
.boxed_local() )),
} ConnectionType::H2(io) => Box::new(h2proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
} }
} }
type TunnelFuture = Either< type TunnelFuture = Either<
LocalBoxFuture< Box<
'static, dyn Future<
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>,
>, >,
Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>, FutureResult<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => {
Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) Either::A(Box::new(h1proto::open_tunnel(io, head.into())))
} }
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -145,7 +152,7 @@ where
None, None,
)); ));
} }
Either::Right(err(SendRequestError::TunnelNotSupported)) Either::B(err(SendRequestError::TunnelNotSupported))
} }
} }
} }
@@ -159,12 +166,12 @@ pub(crate) enum EitherConnection<A, B> {
impl<A, B> Connection for EitherConnection<A, B> impl<A, B> Connection for EitherConnection<A, B>
where where
A: AsyncRead + AsyncWrite + Unpin + 'static, A: AsyncRead + AsyncWrite + 'static,
B: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + 'static,
{ {
type Io = EitherIo<A, B>; type Io = EitherIo<A, B>;
type Future = type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self { match self {
@@ -184,30 +191,44 @@ where
} }
} }
type TunnelFuture = LocalBoxFuture< type TunnelFuture = Box<
'static, dyn Future<
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
match self { match self {
EitherConnection::A(con) => con EitherConnection::A(con) => Box::new(
.open_tunnel(head) con.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) .map(|(head, framed)| (head, framed.map_io(EitherIo::A))),
.boxed_local(), ),
EitherConnection::B(con) => con EitherConnection::B(con) => Box::new(
.open_tunnel(head) con.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) .map(|(head, framed)| (head, framed.map_io(EitherIo::B))),
.boxed_local(), ),
} }
} }
} }
#[pin_project]
pub enum EitherIo<A, B> { pub enum EitherIo<A, B> {
A(#[pin] A), A(A),
B(#[pin] B), B(B),
}
impl<A, B> io::Read for EitherIo<A, B>
where
A: io::Read,
B: io::Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.read(buf),
EitherIo::B(ref mut val) => val.read(buf),
}
}
} }
impl<A, B> AsyncRead for EitherIo<A, B> impl<A, B> AsyncRead for EitherIo<A, B>
@@ -215,19 +236,6 @@ where
A: AsyncRead, A: AsyncRead,
B: AsyncRead, B: AsyncRead,
{ {
#[project]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_read(cx, buf),
EitherIo::B(val) => val.poll_read(cx, buf),
}
}
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self { match self {
EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
@@ -236,58 +244,45 @@ where
} }
} }
impl<A, B> io::Write for EitherIo<A, B>
where
A: io::Write,
B: io::Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.write(buf),
EitherIo::B(ref mut val) => val.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
EitherIo::A(ref mut val) => val.flush(),
EitherIo::B(ref mut val) => val.flush(),
}
}
}
impl<A, B> AsyncWrite for EitherIo<A, B> impl<A, B> AsyncWrite for EitherIo<A, B>
where where
A: AsyncWrite, A: AsyncWrite,
B: AsyncWrite, B: AsyncWrite,
{ {
#[project] fn shutdown(&mut self) -> Poll<(), io::Error> {
fn poll_write( match self {
self: Pin<&mut Self>, EitherIo::A(ref mut val) => val.shutdown(),
cx: &mut Context, EitherIo::B(ref mut val) => val.shutdown(),
buf: &[u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_write(cx, buf),
EitherIo::B(val) => val.poll_write(cx, buf),
} }
} }
#[project] fn write_buf<U: Buf>(&mut self, buf: &mut U) -> Poll<usize, io::Error>
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_flush(cx),
EitherIo::B(val) => val.poll_flush(cx),
}
}
#[project]
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_shutdown(cx),
EitherIo::B(val) => val.poll_shutdown(cx),
}
}
#[project]
fn poll_write_buf<U: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut U,
) -> Poll<Result<usize, io::Error>>
where where
Self: Sized, Self: Sized,
{ {
#[project] match self {
match self.project() { EitherIo::A(ref mut val) => val.write_buf(buf),
EitherIo::A(val) => val.poll_write_buf(cx, buf), EitherIo::B(ref mut val) => val.write_buf(buf),
EitherIo::B(val) => val.poll_write_buf(cx, buf),
} }
} }
} }

View File

@@ -1,41 +1,37 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_connect::{ use actix_connect::{
default_connector, Connect as TcpConnect, Connection as TcpConnection, default_connector, Connect as TcpConnect, Connection as TcpConnection,
}; };
use actix_service::{apply_fn, Service}; use actix_service::{apply_fn, Service, ServiceExt};
use actix_utils::timeout::{TimeoutError, TimeoutService}; use actix_utils::timeout::{TimeoutError, TimeoutService};
use futures::future::Ready;
use http::Uri; use http::Uri;
use tokio_net::tcp::TcpStream; use tokio_tcp::TcpStream;
use super::connection::Connection; use super::connection::Connection;
use super::error::ConnectError; use super::error::ConnectError;
use super::pool::{ConnectionPool, Protocol}; use super::pool::{ConnectionPool, Protocol};
use super::Connect; use super::Connect;
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use open_ssl::ssl::SslConnector as OpensslConnector; use openssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use rust_tls::ClientConfig; use rustls::ClientConfig;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use std::sync::Arc; use std::sync::Arc;
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
enum SslConnector { enum SslConnector {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
Openssl(OpensslConnector), Openssl(OpensslConnector),
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
Rustls(Arc<ClientConfig>), Rustls(Arc<ClientConfig>),
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
type SslConnector = (); type SslConnector = ();
/// Manages http client network connectivity /// Manages http client network connectivity
@@ -62,8 +58,8 @@ pub struct Connector<T, U> {
_t: PhantomData<U>, _t: PhantomData<U>,
} }
trait Io: AsyncRead + AsyncWrite + Unpin {} trait Io: AsyncRead + AsyncWrite {}
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {} impl<T: AsyncRead + AsyncWrite> Io for T {}
impl Connector<(), ()> { impl Connector<(), ()> {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
@@ -76,9 +72,9 @@ impl Connector<(), ()> {
TcpStream, TcpStream,
> { > {
let ssl = { let ssl = {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
{ {
use open_ssl::ssl::SslMethod; use openssl::ssl::SslMethod;
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
let _ = ssl let _ = ssl
@@ -86,7 +82,7 @@ impl Connector<(), ()> {
.map_err(|e| error!("Can not set alpn protocol: {:?}", e)); .map_err(|e| error!("Can not set alpn protocol: {:?}", e));
SslConnector::Openssl(ssl.build()) SslConnector::Openssl(ssl.build())
} }
#[cfg(all(not(feature = "openssl"), feature = "rustls"))] #[cfg(all(not(feature = "ssl"), feature = "rust-tls"))]
{ {
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
@@ -96,7 +92,7 @@ impl Connector<(), ()> {
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
SslConnector::Rustls(Arc::new(config)) SslConnector::Rustls(Arc::new(config))
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
{} {}
}; };
@@ -117,7 +113,7 @@ impl<T, U> Connector<T, U> {
/// Use custom connector. /// Use custom connector.
pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1> pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
where where
U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U1: AsyncRead + AsyncWrite + fmt::Debug,
T1: Service< T1: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U1>, Response = TcpConnection<Uri, U1>,
@@ -139,7 +135,7 @@ impl<T, U> Connector<T, U> {
impl<T, U> Connector<T, U> impl<T, U> Connector<T, U>
where where
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + fmt::Debug + 'static,
T: Service< T: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U>, Response = TcpConnection<Uri, U>,
@@ -154,14 +150,14 @@ where
self self
} }
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
/// Use custom `SslConnector` instance. /// Use custom `SslConnector` instance.
pub fn ssl(mut self, connector: OpensslConnector) -> Self { pub fn ssl(mut self, connector: OpensslConnector) -> Self {
self.ssl = SslConnector::Openssl(connector); self.ssl = SslConnector::Openssl(connector);
self self
} }
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self { pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
self.ssl = SslConnector::Rustls(connector); self.ssl = SslConnector::Rustls(connector);
self self
@@ -216,8 +212,8 @@ where
pub fn finish( pub fn finish(
self, self,
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError> ) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
+ Clone { + Clone {
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
{ {
let connector = TimeoutService::new( let connector = TimeoutService::new(
self.timeout, self.timeout,
@@ -242,32 +238,32 @@ where
), ),
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
{ {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use actix_connect::ssl::OpensslConnector; use actix_connect::ssl::OpensslConnector;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use actix_connect::ssl::RustlsConnector; use actix_connect::ssl::RustlsConnector;
use actix_service::{boxed::service, pipeline}; use actix_service::boxed::service;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use rust_tls::Session; use rustls::Session;
let ssl_service = TimeoutService::new( let ssl_service = TimeoutService::new(
self.timeout, self.timeout,
pipeline( apply_fn(self.connector.clone(), |msg: Connect, srv| {
apply_fn(self.connector.clone(), |msg: Connect, srv| { srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) })
}) .map_err(ConnectError::from)
.map_err(ConnectError::from),
)
.and_then(match self.ssl { .and_then(match self.ssl {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
SslConnector::Openssl(ssl) => service( SslConnector::Openssl(ssl) => service(
OpensslConnector::service(ssl) OpensslConnector::service(ssl)
.map_err(ConnectError::from)
.map(|stream| { .map(|stream| {
let sock = stream.into_parts().0; let sock = stream.into_parts().0;
let h2 = sock let h2 = sock
.get_ref()
.ssl() .ssl()
.selected_alpn_protocol() .selected_alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2)) .map(|protos| protos.windows(2).any(|w| w == H2))
@@ -277,10 +273,9 @@ where
} else { } else {
(Box::new(sock) as Box<dyn Io>, Protocol::Http1) (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
} }
}) }),
.map_err(ConnectError::from),
), ),
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
SslConnector::Rustls(ssl) => service( SslConnector::Rustls(ssl) => service(
RustlsConnector::service(ssl) RustlsConnector::service(ssl)
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -308,7 +303,7 @@ where
let tcp_service = TimeoutService::new( let tcp_service = TimeoutService::new(
self.timeout, self.timeout,
apply_fn(self.connector, |msg: Connect, srv| { apply_fn(self.connector.clone(), |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
}) })
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -339,20 +334,19 @@ where
} }
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
mod connect_impl { mod connect_impl {
use std::task::{Context, Poll}; use futures::future::{err, Either, FutureResult};
use futures::Poll;
use futures::future::{err, Either, Ready};
use futures::ready;
use super::*; use super::*;
use crate::client::connection::IoConnection; use crate::client::connection::IoConnection;
pub(crate) struct InnerConnector<T, Io> pub(crate) struct InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) tcp_pool: ConnectionPool<T, Io>, pub(crate) tcp_pool: ConnectionPool<T, Io>,
@@ -360,8 +354,9 @@ mod connect_impl {
impl<T, Io> Clone for InnerConnector<T, Io> impl<T, Io> Clone for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -373,8 +368,9 @@ mod connect_impl {
impl<T, Io> Service for InnerConnector<T, Io> impl<T, Io> Service for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
@@ -382,38 +378,38 @@ mod connect_impl {
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
<ConnectionPool<T, Io> as Service>::Future, <ConnectionPool<T, Io> as Service>::Future,
Ready<Result<IoConnection<Io>, ConnectError>>, FutureResult<IoConnection<Io>, ConnectError>,
>; >;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready(cx) self.tcp_pool.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => { Some("https") | Some("wss") => {
Either::Right(err(ConnectError::SslIsNotSupported)) Either::B(err(ConnectError::SslIsNotSupported))
} }
_ => Either::Left(self.tcp_pool.call(req)), _ => Either::A(self.tcp_pool.call(req)),
} }
} }
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
mod connect_impl { mod connect_impl {
use std::marker::PhantomData; use std::marker::PhantomData;
use futures::future::Either; use futures::future::{Either, FutureResult};
use futures::ready; use futures::{Async, Future, Poll};
use super::*; use super::*;
use crate::client::connection::EitherConnection; use crate::client::connection::EitherConnection;
pub(crate) struct InnerConnector<T1, T2, Io1, Io2> pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>, T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>, T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
{ {
@@ -423,11 +419,13 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -440,47 +438,53 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = EitherConnection<Io1, Io2>; type Response = EitherConnection<Io1, Io2>;
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
InnerConnectorResponseA<T1, Io1, Io2>, FutureResult<Self::Response, Self::Error>,
InnerConnectorResponseB<T2, Io1, Io2>, Either<
InnerConnectorResponseA<T1, Io1, Io2>,
InnerConnectorResponseB<T2, Io1, Io2>,
>,
>; >;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready(cx) self.tcp_pool.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { Some("https") | Some("wss") => {
fut: self.ssl_pool.call(req), Either::B(Either::B(InnerConnectorResponseB {
_t: PhantomData, fut: self.ssl_pool.call(req),
}), _t: PhantomData,
_ => Either::Left(InnerConnectorResponseA { }))
}
_ => Either::B(Either::A(InnerConnectorResponseA {
fut: self.tcp_pool.call(req), fut: self.tcp_pool.call(req),
_t: PhantomData, _t: PhantomData,
}), })),
} }
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseA<T, Io1, Io2> pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io1> as Service>::Future, fut: <ConnectionPool<T, Io1> as Service>::Future,
_t: PhantomData<Io2>, _t: PhantomData<Io2>,
} }
@@ -488,28 +492,29 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
{ {
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>; type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Poll::Ready( match self.fut.poll()? {
ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) Async::NotReady => Ok(Async::NotReady),
.map(|res| EitherConnection::A(res)), Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))),
) }
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseB<T, Io1, Io2> pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
where where
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io2> as Service>::Future, fut: <ConnectionPool<T, Io2> as Service>::Future,
_t: PhantomData<Io1>, _t: PhantomData<Io1>,
} }
@@ -517,17 +522,19 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
{ {
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>; type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Poll::Ready( match self.fut.poll()? {
ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) Async::NotReady => Ok(Async::NotReady),
.map(|res| EitherConnection::B(res)), Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))),
) }
} }
} }
} }

View File

@@ -3,8 +3,8 @@ use std::io;
use derive_more::{Display, From}; use derive_more::{Display, From};
use trust_dns_resolver::error::ResolveError; use trust_dns_resolver::error::ResolveError;
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use open_ssl::ssl::{Error as SslError, HandshakeError}; use openssl::ssl::{Error as SslError, HandshakeError};
use crate::error::{Error, ParseError, ResponseError}; use crate::error::{Error, ParseError, ResponseError};
use crate::http::Error as HttpError; use crate::http::Error as HttpError;
@@ -18,7 +18,7 @@ pub enum ConnectError {
SslIsNotSupported, SslIsNotSupported,
/// SSL error /// SSL error
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
#[display(fmt = "{}", _0)] #[display(fmt = "{}", _0)]
SslError(SslError), SslError(SslError),
@@ -63,7 +63,7 @@ impl From<actix_connect::ConnectError> for ConnectError {
} }
} }
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
impl<T> From<HandshakeError<T>> for ConnectError { impl<T> From<HandshakeError<T>> for ConnectError {
fn from(err: HandshakeError<T>) -> ConnectError { fn from(err: HandshakeError<T>) -> ConnectError {
match err { match err {

View File

@@ -1,13 +1,10 @@
use std::future::Future;
use std::io::Write; use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, time}; use std::{io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, poll_fn, Either}; use futures::future::{ok, Either};
use futures::{Sink, SinkExt, Stream, StreamExt}; use futures::{Async, Future, Poll, Sink, Stream};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@@ -21,15 +18,15 @@ use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired; use super::pool::Acquired;
use crate::body::{BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
pub(crate) async fn send_request<T, B>( pub(crate) fn send_request<T, B>(
io: T, io: T,
mut head: RequestHeadType, mut head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
B: MessageBody, B: MessageBody,
{ {
// set request host header // set request host header
@@ -65,99 +62,68 @@ where
io: Some(io), io: Some(io),
}; };
let len = body.size();
// create Framed and send request // create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default()); Framed::new(io, h1::ClientCodec::default())
framed.send((head, body.size()).into()).await?; .send((head, len).into())
.from_err()
// send request body // send request body
match body.size() { .and_then(move |framed| match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
_ => send_body(body, &mut framed).await?, Either::A(ok(framed))
}; }
_ => Either::B(SendBody::new(body, framed)),
// read response and init read body })
let res = framed.into_future().await; // read response and init read body
let (head, framed) = if let (Some(result), framed) = res { .and_then(|framed| {
let item = result.map_err(SendRequestError::from)?; framed
(item, framed) .into_future()
} else { .map_err(|(e, _)| SendRequestError::from(e))
return Err(SendRequestError::from(ConnectError::Disconnected)); .and_then(|(item, framed)| {
}; if let Some(res) = item {
match framed.get_codec().message_type() {
match framed.get_codec().message_type() { h1::MessageType::None => {
h1::MessageType::None => { let force_close = !framed.get_codec().keepalive();
let force_close = !framed.get_codec().keepalive(); release_connection(framed, force_close);
release_connection(framed, force_close); Ok((res, Payload::None))
Ok((head, Payload::None)) }
} _ => {
_ => { let pl: PayloadStream = Box::new(PlStream::new(framed));
let pl: PayloadStream = PlStream::new(framed).boxed_local(); Ok((res, pl.into()))
Ok((head, pl.into())) }
} }
} } else {
Err(ConnectError::Disconnected.into())
}
})
})
} }
pub(crate) async fn open_tunnel<T>( pub(crate) fn open_tunnel<T>(
io: T, io: T,
head: RequestHeadType, head: RequestHeadType,
) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError> ) -> impl Future<Item = (ResponseHead, Framed<T, h1::ClientCodec>), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
// create Framed and send request // create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default()); Framed::new(io, h1::ClientCodec::default())
framed.send((head, BodySize::None).into()).await?; .send((head, BodySize::None).into())
.from_err()
// read response // read response
if let (Some(result), framed) = framed.into_future().await { .and_then(|framed| {
let head = result.map_err(SendRequestError::from)?; framed
Ok((head, framed)) .into_future()
} else { .map_err(|(e, _)| SendRequestError::from(e))
Err(SendRequestError::from(ConnectError::Disconnected)) .and_then(|(head, framed)| {
} if let Some(head) = head {
} Ok((head, framed))
/// send request body to the peer
pub(crate) async fn send_body<I, B>(
mut body: B,
framed: &mut Framed<I, h1::ClientCodec>,
) -> Result<(), SendRequestError>
where
I: ConnectionLifetime,
B: MessageBody,
{
let mut eof = false;
while !eof {
while !eof && !framed.is_write_buf_full() {
match poll_fn(|cx| body.poll_next(cx)).await {
Some(result) => {
framed.write(h1::Message::Chunk(Some(result?)))?;
}
None => {
eof = true;
framed.write(h1::Message::Chunk(None))?;
}
}
}
if !framed.is_write_buf_empty() {
poll_fn(|cx| match framed.flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.is_write_buf_full() {
Poll::Ready(Ok(()))
} else { } else {
Poll::Pending Err(SendRequestError::from(ConnectError::Disconnected))
} }
} })
}) })
.await?;
}
}
SinkExt::flush(framed).await?;
Ok(())
} }
#[doc(hidden)] #[doc(hidden)]
@@ -168,10 +134,7 @@ pub struct H1Connection<T> {
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
} }
impl<T> ConnectionLifetime for H1Connection<T> impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T> {
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close connection /// Close connection
fn close(&mut self) { fn close(&mut self) {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -199,41 +162,98 @@ where
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + 'static> io::Read for H1Connection<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) self.io.as_mut().unwrap().read(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + 'static> AsyncRead for H1Connection<T> {}
fn poll_write(
mut self: Pin<&mut Self>, impl<T: AsyncRead + AsyncWrite + 'static> io::Write for H1Connection<T> {
cx: &mut Context<'_>, fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
buf: &[u8], self.io.as_mut().unwrap().write(buf)
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
} }
fn poll_flush( fn flush(&mut self) -> io::Result<()> {
mut self: Pin<&mut Self>, self.io.as_mut().unwrap().flush()
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
} }
}
fn poll_shutdown( impl<T: AsyncRead + AsyncWrite + 'static> AsyncWrite for H1Connection<T> {
mut self: Pin<&mut Self>, fn shutdown(&mut self) -> Poll<(), io::Error> {
cx: &mut Context, self.io.as_mut().unwrap().shutdown()
) -> Poll<Result<(), io::Error>> { }
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) }
/// Future responsible for sending request body to the peer
pub(crate) struct SendBody<I, B> {
body: Option<B>,
framed: Option<Framed<I, h1::ClientCodec>>,
flushed: bool,
}
impl<I, B> SendBody<I, B>
where
I: AsyncRead + AsyncWrite + 'static,
B: MessageBody,
{
pub(crate) fn new(body: B, framed: Framed<I, h1::ClientCodec>) -> Self {
SendBody {
body: Some(body),
framed: Some(framed),
flushed: true,
}
}
}
impl<I, B> Future for SendBody<I, B>
where
I: ConnectionLifetime,
B: MessageBody,
{
type Item = Framed<I, h1::ClientCodec>;
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut body_ready = true;
loop {
while body_ready
&& self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self.body.as_mut().unwrap().poll_next()? {
Async::Ready(item) => {
// check if body is done
if item.is_none() {
let _ = self.body.take();
}
self.flushed = false;
self.framed
.as_mut()
.unwrap()
.force_send(h1::Message::Chunk(item))?;
break;
}
Async::NotReady => body_ready = false,
}
}
if !self.flushed {
match self.framed.as_mut().unwrap().poll_complete()? {
Async::Ready(_) => {
self.flushed = true;
continue;
}
Async::NotReady => return Ok(Async::NotReady),
}
}
if self.body.is_none() {
return Ok(Async::Ready(self.framed.take().unwrap()));
}
return Ok(Async::NotReady);
}
} }
} }
@@ -250,24 +270,23 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
} }
impl<Io: ConnectionLifetime> Stream for PlStream<Io> { impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let this = self.get_mut(); match self.framed.as_mut().unwrap().poll()? {
Async::NotReady => Ok(Async::NotReady),
match this.framed.as_mut().unwrap().next_item(cx)? { Async::Ready(Some(chunk)) => {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk))) Ok(Async::Ready(Some(chunk)))
} else { } else {
let framed = this.framed.take().unwrap(); let framed = self.framed.take().unwrap();
let force_close = !framed.get_codec().keepalive(); let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close); release_connection(framed, force_close);
Poll::Ready(None) Ok(Async::Ready(None))
} }
} }
Poll::Ready(None) => Poll::Ready(None), Async::Ready(None) => Ok(Async::Ready(None)),
} }
} }
} }

View File

@@ -1,11 +1,9 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time; use std::time;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{err, poll_fn, Either}; use futures::future::{err, Either};
use futures::{Async, Future, Poll};
use h2::{client::SendRequest, SendStream}; use h2::{client::SendRequest, SendStream};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{request::Request, HttpTryFrom, Method, Version}; use http::{request::Request, HttpTryFrom, Method, Version};
@@ -19,15 +17,15 @@ use super::connection::{ConnectionType, IoConnection};
use super::error::SendRequestError; use super::error::SendRequestError;
use super::pool::Acquired; use super::pool::Acquired;
pub(crate) async fn send_request<T, B>( pub(crate) fn send_request<T, B>(
mut io: SendRequest<Bytes>, io: SendRequest<Bytes>,
head: RequestHeadType, head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
B: MessageBody, B: MessageBody,
{ {
trace!("Sending client request: {:?} {:?}", head, body.size()); trace!("Sending client request: {:?} {:?}", head, body.size());
@@ -38,140 +36,158 @@ where
_ => false, _ => false,
}; };
let mut req = Request::new(()); io.ready()
*req.uri_mut() = head.as_ref().uri.clone(); .map_err(SendRequestError::from)
*req.method_mut() = head.as_ref().method.clone(); .and_then(move |mut io| {
*req.version_mut() = Version::HTTP_2; let mut req = Request::new(());
*req.uri_mut() = head.as_ref().uri.clone();
*req.method_mut() = head.as_ref().method.clone();
*req.version_mut() = Version::HTTP_2;
let mut skip_len = true; let mut skip_len = true;
// let mut has_date = false; // let mut has_date = false;
// Content length // Content length
let _ = match length { let _ = match length {
BodySize::None => None, BodySize::None => None,
BodySize::Stream => { BodySize::Stream => {
skip_len = false; skip_len = false;
None None
} }
BodySize::Empty => req BodySize::Empty => req
.headers_mut() .headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from_static("0")), .insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
BodySize::Sized(len) => req.headers_mut().insert( BodySize::Sized(len) => req.headers_mut().insert(
CONTENT_LENGTH, CONTENT_LENGTH,
HeaderValue::try_from(format!("{}", len)).unwrap(), HeaderValue::try_from(format!("{}", len)).unwrap(),
), ),
BodySize::Sized64(len) => req.headers_mut().insert( BodySize::Sized64(len) => req.headers_mut().insert(
CONTENT_LENGTH, CONTENT_LENGTH,
HeaderValue::try_from(format!("{}", len)).unwrap(), HeaderValue::try_from(format!("{}", len)).unwrap(),
), ),
}; };
// Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate.
let (head, extra_headers) = match head { let (head, extra_headers) = match head {
RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), RequestHeadType::Owned(head) => {
RequestHeadType::Rc(head, extra_headers) => ( (RequestHeadType::Owned(head), HeaderMap::new())
RequestHeadType::Rc(head, None), }
extra_headers.unwrap_or_else(HeaderMap::new), RequestHeadType::Rc(head, extra_headers) => (
), RequestHeadType::Rc(head, None),
}; extra_headers.unwrap_or_else(HeaderMap::new),
),
};
// merging headers from head and extra headers. // merging headers from head and extra headers.
let headers = head let headers = head
.as_ref() .as_ref()
.headers .headers
.iter() .iter()
.filter(|(name, _)| !extra_headers.contains_key(*name)) .filter(|(name, _)| !extra_headers.contains_key(*name))
.chain(extra_headers.iter()); .chain(extra_headers.iter());
// copy headers // copy headers
for (key, value) in headers { for (key, value) in headers {
match *key { match *key {
CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONNECTION | TRANSFER_ENCODING => continue, // http2 specific
CONTENT_LENGTH if skip_len => continue, CONTENT_LENGTH if skip_len => continue,
// DATE => has_date = true, // DATE => has_date = true,
_ => (), _ => (),
} }
req.headers_mut().append(key, value.clone()); req.headers_mut().append(key, value.clone());
}
let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res {
release(io, pool, created, e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
release(io, pool, created, false);
if !eof {
send_body(body, send).await?;
} }
fut.await.map_err(SendRequestError::from)?
}
Err(e) => {
release(io, pool, created, e.is_io());
return Err(e.into());
}
};
let (parts, body) = resp.into_parts(); match io.send_request(req, eof) {
let payload = if head_req { Payload::None } else { body.into() }; Ok((res, send)) => {
release(io, pool, created, false);
let mut head = ResponseHead::new(parts.status); if !eof {
head.version = parts.version; Either::A(Either::B(
head.headers = parts.headers.into(); SendBody {
Ok((head, payload)) body,
send,
buf: None,
}
.and_then(move |_| res.map_err(SendRequestError::from)),
))
} else {
Either::B(res.map_err(SendRequestError::from))
}
}
Err(e) => {
release(io, pool, created, e.is_io());
Either::A(Either::A(err(e.into())))
}
}
})
.and_then(move |resp| {
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };
let mut head = ResponseHead::new(parts.status);
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
})
.from_err()
} }
async fn send_body<B: MessageBody>( struct SendBody<B: MessageBody> {
mut body: B, body: B,
mut send: SendStream<Bytes>, send: SendStream<Bytes>,
) -> Result<(), SendRequestError> { buf: Option<Bytes>,
let mut buf = None; }
loop {
if buf.is_none() { impl<B: MessageBody> Future for SendBody<B> {
match poll_fn(|cx| body.poll_next(cx)).await { type Item = ();
Some(Ok(b)) => { type Error = SendRequestError;
send.reserve_capacity(b.len());
buf = Some(b); fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
if self.buf.is_none() {
match self.body.poll_next() {
Ok(Async::Ready(Some(buf))) => {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
}
Ok(Async::Ready(None)) => {
if let Err(e) = self.send.send_data(Bytes::new(), true) {
return Err(e.into());
}
self.send.reserve_capacity(0);
return Ok(Async::Ready(()));
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => return Err(e.into()),
} }
Some(Err(e)) => return Err(e.into()), }
None => {
if let Err(e) = send.send_data(Bytes::new(), true) { match self.send.poll_capacity() {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
Ok(Async::Ready(Some(cap))) => {
let mut buf = self.buf.take().unwrap();
let len = buf.len();
let bytes = buf.split_to(std::cmp::min(cap, len));
if let Err(e) = self.send.send_data(bytes, false) {
return Err(e.into()); return Err(e.into());
}
send.reserve_capacity(0);
return Ok(());
}
}
}
match poll_fn(|cx| send.poll_capacity(cx)).await {
None => return Ok(()),
Some(Ok(cap)) => {
let b = buf.as_mut().unwrap();
let len = b.len();
let bytes = b.split_to(std::cmp::min(cap, len));
if let Err(e) = send.send_data(bytes, false) {
return Err(e.into());
} else {
if !b.is_empty() {
send.reserve_capacity(b.len());
} else { } else {
buf = None; if !buf.is_empty() {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
}
continue;
} }
continue;
} }
Err(e) => return Err(e.into()),
} }
Some(Err(e)) => return Err(e.into()),
} }
} }
} }
// release SendRequest object // release SendRequest object
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>( fn release<T: AsyncRead + AsyncWrite + 'static>(
io: SendRequest<Bytes>, io: SendRequest<Bytes>,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
created: time::Instant, created: time::Instant,

View File

@@ -1,23 +1,22 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::io; use std::io;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::Service; use actix_service::Service;
use actix_utils::{oneshot, task::LocalWaker};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{err, ok, poll_fn, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{err, ok, Either, FutureResult};
use h2::client::{handshake, Connection, SendRequest}; use futures::task::AtomicTask;
use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use h2::client::{handshake, Handshake};
use hashbrown::HashMap; use hashbrown::HashMap;
use http::uri::Authority; use http::uri::Authority;
use indexmap::IndexSet; use indexmap::IndexSet;
use slab::Slab; use slab::Slab;
use tokio_timer::{delay_for, Delay}; use tokio_timer::{sleep, Delay};
use super::connection::{ConnectionType, IoConnection}; use super::connection::{ConnectionType, IoConnection};
use super::error::ConnectError; use super::error::ConnectError;
@@ -42,12 +41,16 @@ impl From<Authority> for Key {
} }
/// Connections pool /// Connections pool
pub(crate) struct ConnectionPool<T, Io: 'static>(Rc<RefCell<T>>, Rc<RefCell<Inner<Io>>>); pub(crate) struct ConnectionPool<T, Io: AsyncRead + AsyncWrite + 'static>(
T,
Rc<RefCell<Inner<Io>>>,
);
impl<T, Io> ConnectionPool<T, Io> impl<T, Io> ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) fn new( pub(crate) fn new(
@@ -58,7 +61,7 @@ where
limit: usize, limit: usize,
) -> Self { ) -> Self {
ConnectionPool( ConnectionPool(
Rc::new(RefCell::new(connector)), connector,
Rc::new(RefCell::new(Inner { Rc::new(RefCell::new(Inner {
conn_lifetime, conn_lifetime,
conn_keep_alive, conn_keep_alive,
@@ -68,7 +71,7 @@ where
waiters: Slab::new(), waiters: Slab::new(),
waiters_queue: IndexSet::new(), waiters_queue: IndexSet::new(),
available: HashMap::new(), available: HashMap::new(),
waker: LocalWaker::new(), task: None,
})), })),
) )
} }
@@ -76,7 +79,8 @@ where
impl<T, Io> Clone for ConnectionPool<T, Io> impl<T, Io> Clone for ConnectionPool<T, Io>
where where
Io: 'static, T: Clone,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ConnectionPool(self.0.clone(), self.1.clone()) ConnectionPool(self.0.clone(), self.1.clone())
@@ -85,116 +89,86 @@ where
impl<T, Io> Service for ConnectionPool<T, Io> impl<T, Io> Service for ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = IoConnection<Io>; type Response = IoConnection<Io>;
type Error = ConnectError; type Error = ConnectError;
type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>; type Future = Either<
FutureResult<Self::Response, Self::Error>,
Either<WaitForConnection<Io>, OpenConnection<T::Future, Io>>,
>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.0.poll_ready(cx) self.0.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
// start support future let key = if let Some(authority) = req.uri.authority_part() {
tokio_executor::current_thread::spawn(ConnectorPoolSupport { authority.clone().into()
connector: self.0.clone(), } else {
inner: self.1.clone(), return Either::A(err(ConnectError::Unresolverd));
});
let mut connector = self.0.clone();
let inner = self.1.clone();
let fut = async move {
let key = if let Some(authority) = req.uri.authority_part() {
authority.clone().into()
} else {
return Err(ConnectError::Unresolverd);
};
// acquire connection
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
Acquire::Acquired(io, created) => {
// use existing connection
return Ok(IoConnection::new(
io,
created,
Some(Acquired(key, Some(inner))),
));
}
Acquire::Available => {
// open tcp connection
let (io, proto) = connector.call(req).await?;
let guard = OpenGuard::new(key, inner);
if proto == Protocol::Http1 {
Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(guard.consume()),
))
} else {
let (snd, connection) = handshake(io).await?;
tokio_executor::current_thread::spawn(connection.map(|_| ()));
Ok(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(guard.consume()),
))
}
}
_ => {
// connection is not available, wait
let (rx, token) = inner.borrow_mut().wait_for(req);
let guard = WaiterGuard::new(key, token, inner);
let res = match rx.await {
Err(_) => Err(ConnectError::Disconnected),
Ok(res) => res,
};
guard.consume();
res
}
}
}; };
fut.boxed_local() // acquire connection
match self.1.as_ref().borrow_mut().acquire(&key) {
Acquire::Acquired(io, created) => {
// use existing connection
return Either::A(ok(IoConnection::new(
io,
created,
Some(Acquired(key, Some(self.1.clone()))),
)));
}
Acquire::Available => {
// open new connection
return Either::B(Either::B(OpenConnection::new(
key,
self.1.clone(),
self.0.call(req),
)));
}
_ => (),
}
// connection is not available, wait
let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req);
// start support future
if !support {
self.1.as_ref().borrow_mut().task = Some(AtomicTask::new());
tokio_current_thread::spawn(ConnectorPoolSupport {
connector: self.0.clone(),
inner: self.1.clone(),
})
}
Either::B(Either::A(WaitForConnection {
rx,
key,
token,
inner: Some(self.1.clone()),
}))
} }
} }
struct WaiterGuard<Io> #[doc(hidden)]
pub struct WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
key: Key, key: Key,
token: usize, token: usize,
rx: oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<Io> WaiterGuard<Io> impl<Io> Drop for WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{
fn new(key: Key, token: usize, inner: Rc<RefCell<Inner<Io>>>) -> Self {
Self {
key,
token,
inner: Some(inner),
}
}
fn consume(mut self) {
let _ = self.inner.take();
}
}
impl<Io> Drop for WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(i) = self.inner.take() { if let Some(i) = self.inner.take() {
@@ -205,43 +179,113 @@ where
} }
} }
struct OpenGuard<Io> impl<Io> Future for WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite,
{ {
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.rx.poll() {
Ok(Async::Ready(item)) => match item {
Err(err) => Err(err),
Ok(conn) => {
let _ = self.inner.take();
Ok(Async::Ready(conn))
}
},
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => {
let _ = self.inner.take();
Err(ConnectError::Disconnected)
}
}
}
}
#[doc(hidden)]
pub struct OpenConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
{
fut: F,
key: Key, key: Key,
h2: Option<Handshake<Io, Bytes>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<Io> OpenGuard<Io> impl<F, Io> OpenConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>) -> Self { fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>, fut: F) -> Self {
Self { OpenConnection {
key, key,
fut,
inner: Some(inner), inner: Some(inner),
h2: None,
} }
} }
fn consume(mut self) -> Acquired<Io> {
Acquired(self.key.clone(), self.inner.take())
}
} }
impl<Io> Drop for OpenGuard<Io> impl<F, Io> Drop for OpenConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(i) = self.inner.take() { if let Some(inner) = self.inner.take() {
let mut inner = i.as_ref().borrow_mut(); let mut inner = inner.as_ref().borrow_mut();
inner.release(); inner.release();
inner.check_availibility(); inner.check_availibility();
} }
} }
} }
impl<F, Io> Future for OpenConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite,
{
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => {
tokio_current_thread::spawn(connection.map_err(|_| ()));
Ok(Async::Ready(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e.into()),
};
}
match self.fut.poll() {
Err(err) => Err(err),
Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 {
Ok(Async::Ready(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
} else {
self.h2 = Some(handshake(io));
self.poll()
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
}
}
}
enum Acquire<T> { enum Acquire<T> {
Acquired(ConnectionType<T>, Instant), Acquired(ConnectionType<T>, Instant),
Available, Available,
@@ -268,7 +312,7 @@ pub(crate) struct Inner<Io> {
)>, )>,
>, >,
waiters_queue: IndexSet<(Key, usize)>, waiters_queue: IndexSet<(Key, usize)>,
waker: LocalWaker, task: Option<AtomicTask>,
} }
impl<Io> Inner<Io> { impl<Io> Inner<Io> {
@@ -288,7 +332,7 @@ impl<Io> Inner<Io> {
impl<Io> Inner<Io> impl<Io> Inner<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
/// connection is not available, wait /// connection is not available, wait
fn wait_for( fn wait_for(
@@ -297,6 +341,7 @@ where
) -> ( ) -> (
oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>, oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
usize, usize,
bool,
) { ) {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
@@ -306,10 +351,10 @@ where
entry.insert(Some((connect, tx))); entry.insert(Some((connect, tx)));
assert!(self.waiters_queue.insert((key, token))); assert!(self.waiters_queue.insert((key, token)));
(rx, token) (rx, token, self.task.is_some())
} }
fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire<Io> { fn acquire(&mut self, key: &Key) -> Acquire<Io> {
// check limits // check limits
if self.limit > 0 && self.acquired >= self.limit { if self.limit > 0 && self.acquired >= self.limit {
return Acquire::NotAvailable; return Acquire::NotAvailable;
@@ -328,7 +373,7 @@ where
{ {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = conn.io { if let ConnectionType::H1(io) = conn.io {
tokio_executor::current_thread::spawn(CloseConnection::new( tokio_current_thread::spawn(CloseConnection::new(
io, timeout, io, timeout,
)) ))
} }
@@ -337,19 +382,19 @@ where
let mut io = conn.io; let mut io = conn.io;
let mut buf = [0; 2]; let mut buf = [0; 2];
if let ConnectionType::H1(ref mut s) = io { if let ConnectionType::H1(ref mut s) = io {
match Pin::new(s).poll_read(cx, &mut buf) { match s.read(&mut buf) {
Poll::Pending => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Poll::Ready(Ok(n)) if n > 0 => { Ok(n) if n > 0 => {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
tokio_executor::current_thread::spawn( tokio_current_thread::spawn(
CloseConnection::new(io, timeout), CloseConnection::new(io, timeout),
) )
} }
} }
continue; continue;
} }
_ => continue, Ok(_) | Err(_) => continue,
} }
} }
return Acquire::Acquired(io, conn.created); return Acquire::Acquired(io, conn.created);
@@ -376,7 +421,7 @@ where
self.acquired -= 1; self.acquired -= 1;
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
tokio_executor::current_thread::spawn(CloseConnection::new(io, timeout)) tokio_current_thread::spawn(CloseConnection::new(io, timeout))
} }
} }
self.check_availibility(); self.check_availibility();
@@ -384,7 +429,9 @@ where
fn check_availibility(&self) { fn check_availibility(&self) {
if !self.waiters_queue.is_empty() && self.acquired < self.limit { if !self.waiters_queue.is_empty() && self.acquired < self.limit {
self.waker.wake(); if let Some(t) = self.task.as_ref() {
t.notify()
}
} }
} }
} }
@@ -396,30 +443,29 @@ struct CloseConnection<T> {
impl<T> CloseConnection<T> impl<T> CloseConnection<T>
where where
T: AsyncWrite + Unpin, T: AsyncWrite,
{ {
fn new(io: T, timeout: Duration) -> Self { fn new(io: T, timeout: Duration) -> Self {
CloseConnection { CloseConnection {
io, io,
timeout: delay_for(timeout), timeout: sleep(timeout),
} }
} }
} }
impl<T> Future for CloseConnection<T> impl<T> Future for CloseConnection<T>
where where
T: AsyncWrite + Unpin, T: AsyncWrite,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { fn poll(&mut self) -> Poll<(), ()> {
let this = self.get_mut(); match self.timeout.poll() {
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
match Pin::new(&mut this.timeout).poll(cx) { Ok(Async::NotReady) => match self.io.shutdown() {
Poll::Ready(_) => Poll::Ready(()), Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}, },
} }
} }
@@ -427,7 +473,7 @@ where
struct ConnectorPoolSupport<T, Io> struct ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
connector: T, connector: T,
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
@@ -435,17 +481,16 @@ where
impl<T, Io> Future for ConnectorPoolSupport<T, Io> impl<T, Io> Future for ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>, T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>,
T::Future: 'static, T::Future: 'static,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = unsafe { self.get_unchecked_mut() }; let mut inner = self.inner.as_ref().borrow_mut();
inner.task.as_ref().unwrap().register();
let mut inner = this.inner.as_ref().borrow_mut();
inner.waker.register(cx.waker());
// check waiters // check waiters
loop { loop {
@@ -460,14 +505,14 @@ where
continue; continue;
} }
match inner.acquire(&key, cx) { match inner.acquire(&key) {
Acquire::NotAvailable => break, Acquire::NotAvailable => break,
Acquire::Acquired(io, created) => { Acquire::Acquired(io, created) => {
let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1;
if let Err(conn) = tx.send(Ok(IoConnection::new( if let Err(conn) = tx.send(Ok(IoConnection::new(
io, io,
created, created,
Some(Acquired(key.clone(), Some(this.inner.clone()))), Some(Acquired(key.clone(), Some(self.inner.clone()))),
))) { ))) {
let (io, created) = conn.unwrap().into_inner(); let (io, created) = conn.unwrap().into_inner();
inner.release_conn(&key, io, created); inner.release_conn(&key, io, created);
@@ -479,38 +524,33 @@ where
OpenWaitingConnection::spawn( OpenWaitingConnection::spawn(
key.clone(), key.clone(),
tx, tx,
this.inner.clone(), self.inner.clone(),
this.connector.call(connect), self.connector.call(connect),
); );
} }
} }
let _ = inner.waiters_queue.swap_remove_index(0); let _ = inner.waiters_queue.swap_remove_index(0);
} }
Poll::Pending Ok(Async::NotReady)
} }
} }
struct OpenWaitingConnection<F, Io> struct OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fut: F, fut: F,
key: Key, key: Key,
h2: Option< h2: Option<Handshake<Io, Bytes>>,
LocalBoxFuture<
'static,
Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
>,
>,
rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>, rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<F, Io> OpenWaitingConnection<F, Io> impl<F, Io> OpenWaitingConnection<F, Io>
where where
F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static, F: Future<Item = (Io, Protocol), Error = ConnectError> + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn spawn( fn spawn(
key: Key, key: Key,
@@ -518,7 +558,7 @@ where
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
fut: F, fut: F,
) { ) {
tokio_executor::current_thread::spawn(OpenWaitingConnection { tokio_current_thread::spawn(OpenWaitingConnection {
key, key,
fut, fut,
h2: None, h2: None,
@@ -530,7 +570,7 @@ where
impl<F, Io> Drop for OpenWaitingConnection<F, Io> impl<F, Io> Drop for OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(inner) = self.inner.take() { if let Some(inner) = self.inner.take() {
@@ -543,60 +583,59 @@ where
impl<F, Io> Future for OpenWaitingConnection<F, Io> impl<F, Io> Future for OpenWaitingConnection<F, Io>
where where
F: Future<Output = Result<(Io, Protocol), ConnectError>>, F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite + Unpin, Io: AsyncRead + AsyncWrite,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = unsafe { self.get_unchecked_mut() }; if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
if let Some(ref mut h2) = this.h2 { Ok(Async::Ready((snd, connection))) => {
return match Pin::new(h2).poll(cx) { tokio_current_thread::spawn(connection.map_err(|_| ()));
Poll::Ready(Ok((snd, connection))) => { let rx = self.rx.take().unwrap();
tokio_executor::current_thread::spawn(connection.map(|_| ()));
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(snd), ConnectionType::H2(snd),
Instant::now(), Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())), Some(Acquired(self.key.clone(), self.inner.take())),
))); )));
Poll::Ready(()) Ok(Async::Ready(()))
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err(err)) => { Err(err) => {
let _ = this.inner.take(); let _ = self.inner.take();
if let Some(rx) = this.rx.take() { if let Some(rx) = self.rx.take() {
let _ = rx.send(Err(ConnectError::H2(err))); let _ = rx.send(Err(ConnectError::H2(err)));
} }
Poll::Ready(()) Err(())
} }
}; };
} }
match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) { match self.fut.poll() {
Poll::Ready(Err(err)) => { Err(err) => {
let _ = this.inner.take(); let _ = self.inner.take();
if let Some(rx) = this.rx.take() { if let Some(rx) = self.rx.take() {
let _ = rx.send(Err(err)); let _ = rx.send(Err(err));
} }
Poll::Ready(()) Err(())
} }
Poll::Ready(Ok((io, proto))) => { Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 { if proto == Protocol::Http1 {
let rx = this.rx.take().unwrap(); let rx = self.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io), ConnectionType::H1(io),
Instant::now(), Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())), Some(Acquired(self.key.clone(), self.inner.take())),
))); )));
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
this.h2 = Some(handshake(io).boxed_local()); self.h2 = Some(handshake(io));
unsafe { Pin::new_unchecked(this) }.poll(cx) self.poll()
} }
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
} }
} }
} }
@@ -605,7 +644,7 @@ pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
impl<T> Acquired<T> impl<T> Acquired<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
pub(crate) fn close(&mut self, conn: IoConnection<T>) { pub(crate) fn close(&mut self, conn: IoConnection<T>) {
if let Some(inner) = self.1.take() { if let Some(inner) = self.1.take() {

View File

@@ -1,8 +1,8 @@
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service; use actix_service::Service;
use futures::Poll;
#[doc(hidden)] #[doc(hidden)]
/// Service that allows to turn non-clone service to a service with `Clone` impl /// Service that allows to turn non-clone service to a service with `Clone` impl
@@ -32,8 +32,8 @@ where
type Error = T::Error; type Error = T::Error;
type Future = T::Future; type Future = T::Future;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx) unsafe { &mut *self.0.as_ref().get() }.poll_ready()
} }
fn call(&mut self, req: T::Request) -> Self::Future { fn call(&mut self, req: T::Request) -> Self::Future {

View File

@@ -5,9 +5,9 @@ use std::rc::Rc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::BytesMut; use bytes::BytesMut;
use futures::{future, Future, FutureExt}; use futures::{future, Future};
use time; use time;
use tokio_timer::{delay, delay_for, Delay}; use tokio_timer::{sleep, Delay};
// "Sun, 06 Nov 1994 08:49:37 GMT".len() // "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29; const DATE_VALUE_LENGTH: usize = 29;
@@ -104,10 +104,10 @@ impl ServiceConfig {
#[inline] #[inline]
/// Client timeout for first request. /// Client timeout for first request.
pub fn client_timer(&self) -> Option<Delay> { pub fn client_timer(&self) -> Option<Delay> {
let delay_time = self.0.client_timeout; let delay = self.0.client_timeout;
if delay_time != 0 { if delay != 0 {
Some(delay( Some(Delay::new(
self.0.timer.now() + Duration::from_millis(delay_time), self.0.timer.now() + Duration::from_millis(delay),
)) ))
} else { } else {
None None
@@ -138,7 +138,7 @@ impl ServiceConfig {
/// Return keep-alive timer delay is configured. /// Return keep-alive timer delay is configured.
pub fn keep_alive_timer(&self) -> Option<Delay> { pub fn keep_alive_timer(&self) -> Option<Delay> {
if let Some(ka) = self.0.keep_alive { if let Some(ka) = self.0.keep_alive {
Some(delay(self.0.timer.now() + ka)) Some(Delay::new(self.0.timer.now() + ka))
} else { } else {
None None
} }
@@ -242,12 +242,12 @@ impl DateService {
// periodic date update // periodic date update
let s = self.clone(); let s = self.clone();
tokio_executor::current_thread::spawn( tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then(
delay_for(Duration::from_millis(500)).then(move |_| { move |_| {
s.0.reset(); s.0.reset();
future::ready(()) future::ok(())
}), },
); ));
} }
} }
@@ -277,7 +277,7 @@ mod tests {
fn test_date() { fn test_date() {
let mut rt = System::new("test"); let mut rt = System::new("test");
let _ = rt.block_on(future::lazy(|_| { let _ = rt.block_on(future::lazy(|| {
let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); let settings = ServiceConfig::new(KeepAlive::Os, 0, 0);
let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10);
settings.set_date(&mut buf1); settings.set_date(&mut buf1);

View File

@@ -1,12 +1,13 @@
use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256}; use ring::digest::{Algorithm, SHA256};
use ring::hmac; use ring::hkdf::expand;
use ring::hmac::SigningKey;
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::private::KEY_LEN as PRIVATE_KEY_LEN; use super::private::KEY_LEN as PRIVATE_KEY_LEN;
use super::signed::KEY_LEN as SIGNED_KEY_LEN; use super::signed::KEY_LEN as SIGNED_KEY_LEN;
static HKDF_DIGEST: Algorithm = HKDF_SHA256; static HKDF_DIGEST: &Algorithm = &SHA256;
const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"]; const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM";
/// A cryptographic master key for use with `Signed` and/or `Private` jars. /// A cryptographic master key for use with `Signed` and/or `Private` jars.
/// ///
@@ -24,13 +25,6 @@ pub struct Key {
encryption_key: [u8; PRIVATE_KEY_LEN], encryption_key: [u8; PRIVATE_KEY_LEN],
} }
impl KeyType for &Key {
#[inline]
fn len(&self) -> usize {
SIGNED_KEY_LEN + PRIVATE_KEY_LEN
}
}
impl Key { impl Key {
/// Derives new signing/encryption keys from a master key. /// Derives new signing/encryption keys from a master key.
/// ///
@@ -62,26 +56,21 @@ impl Key {
); );
} }
// An empty `Key` structure; will be filled in with HKDF derived keys. // Expand the user's key into two.
let mut output_key = Key { let prk = SigningKey::new(HKDF_DIGEST, key);
signing_key: [0; SIGNED_KEY_LEN],
encryption_key: [0; PRIVATE_KEY_LEN],
};
// Expand the master key into two HKDF generated keys.
let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN];
let prk = Prk::new_less_safe(HKDF_DIGEST, key); expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys);
let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand");
okm.fill(&mut both_keys).expect("fill keys");
// Copy the key parts into their respective fields. // Copy the keys into their respective arrays.
output_key let mut signing_key = [0; SIGNED_KEY_LEN];
.signing_key let mut encryption_key = [0; PRIVATE_KEY_LEN];
.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]);
output_key encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]);
.encryption_key
.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); Key {
output_key signing_key,
encryption_key,
}
} }
/// Generates signing/encryption keys from a secure, random source. Keys are /// Generates signing/encryption keys from a secure, random source. Keys are

View File

@@ -1,8 +1,8 @@
use std::str; use std::str;
use log::warn; use log::warn;
use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM}; use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM};
use ring::aead::{LessSafeKey, UnboundKey}; use ring::aead::{OpeningKey, SealingKey};
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::Key; use super::Key;
@@ -10,7 +10,7 @@ use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `private` docs as // Keep these in sync, and keep the key len synced with the `private` docs as
// well as the `KEYS_INFO` const in secure::Key. // well as the `KEYS_INFO` const in secure::Key.
static ALGO: &'static Algorithm = &AES_256_GCM; static ALGO: &Algorithm = &AES_256_GCM;
const NONCE_LEN: usize = 12; const NONCE_LEN: usize = 12;
pub const KEY_LEN: usize = 32; pub const KEY_LEN: usize = 32;
@@ -53,14 +53,11 @@ impl<'a> PrivateJar<'a> {
} }
let ad = Aad::from(name.as_bytes()); let ad = Aad::from(name.as_bytes());
let key = LessSafeKey::new( let key = OpeningKey::new(ALGO, &self.key).expect("opening key");
UnboundKey::new(&ALGO, &self.key).expect("matching key length"), let (nonce, sealed) = data.split_at_mut(NONCE_LEN);
);
let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN);
let nonce = let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
let unsealed = key let unsealed = open_in_place(&key, nonce, ad, 0, sealed)
.open_in_place(nonce, ad, &mut sealed)
.map_err(|_| "invalid key/nonce/value: bad seal")?; .map_err(|_| "invalid key/nonce/value: bad seal")?;
if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { if let Ok(unsealed_utf8) = str::from_utf8(unsealed) {
@@ -199,33 +196,30 @@ Please change it as soon as possible."
fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> { fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> {
// Create the `SealingKey` structure. // Create the `SealingKey` structure.
let unbound = UnboundKey::new(&ALGO, key).expect("matching key length"); let key = SealingKey::new(ALGO, key).expect("sealing key creation");
let key = LessSafeKey::new(unbound);
// Create a vec to hold the [nonce | cookie value | overhead]. // Create a vec to hold the [nonce | cookie value | overhead].
let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()]; let overhead = ALGO.tag_len();
let mut data = vec![0; NONCE_LEN + value.len() + overhead];
// Randomly generate the nonce, then copy the cookie value as input. // Randomly generate the nonce, then copy the cookie value as input.
let (nonce, in_out) = data.split_at_mut(NONCE_LEN); let (nonce, in_out) = data.split_at_mut(NONCE_LEN);
let (in_out, tag) = in_out.split_at_mut(value.len());
in_out.copy_from_slice(value);
// Randomly generate the nonce into the nonce piece.
SystemRandom::new() SystemRandom::new()
.fill(nonce) .fill(nonce)
.expect("couldn't random fill nonce"); .expect("couldn't random fill nonce");
let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length"); in_out[..value.len()].copy_from_slice(value);
let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
// Use cookie's name as associated data to prevent value swapping. // Use cookie's name as associated data to prevent value swapping.
let ad = Aad::from(name); let ad = Aad::from(name);
let ad_tag = key
.seal_in_place_separate_tag(nonce, ad, in_out)
.expect("in-place seal");
// Copy the tag into the tag piece. // Perform the actual sealing operation and get the output length.
tag.copy_from_slice(ad_tag.as_ref()); let output_len =
seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal");
// Remove the overhead and return the sealed content. // Remove the overhead and return the sealed content.
data.truncate(NONCE_LEN + output_len);
data data
} }

View File

@@ -1,11 +1,12 @@
use ring::hmac::{self, sign, verify}; use ring::digest::{Algorithm, SHA256};
use ring::hmac::{sign, verify_with_own_key as verify, SigningKey};
use super::Key; use super::Key;
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `signed` docs as // Keep these in sync, and keep the key len synced with the `signed` docs as
// well as the `KEYS_INFO` const in secure::Key. // well as the `KEYS_INFO` const in secure::Key.
static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256; static HMAC_DIGEST: &Algorithm = &SHA256;
const BASE64_DIGEST_LEN: usize = 44; const BASE64_DIGEST_LEN: usize = 44;
pub const KEY_LEN: usize = 32; pub const KEY_LEN: usize = 32;
@@ -20,7 +21,7 @@ pub const KEY_LEN: usize = 32;
/// This type is only available when the `secure` feature is enabled. /// This type is only available when the `secure` feature is enabled.
pub struct SignedJar<'a> { pub struct SignedJar<'a> {
parent: &'a mut CookieJar, parent: &'a mut CookieJar,
key: hmac::Key, key: SigningKey,
} }
impl<'a> SignedJar<'a> { impl<'a> SignedJar<'a> {
@@ -31,7 +32,7 @@ impl<'a> SignedJar<'a> {
pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> {
SignedJar { SignedJar {
parent, parent,
key: hmac::Key::new(HMAC_DIGEST, key.signing()), key: SigningKey::new(HMAC_DIGEST, key.signing()),
} }
} }

View File

@@ -1,7 +1,4 @@
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
@@ -9,7 +6,7 @@ use brotli2::write::BrotliDecoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzDecoder, ZlibDecoder}; use flate2::write::{GzDecoder, ZlibDecoder};
use futures::{ready, Stream}; use futures::{try_ready, Async, Future, Poll, Stream};
use super::Writer; use super::Writer;
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -21,12 +18,12 @@ pub struct Decoder<S> {
decoder: Option<ContentDecoder>, decoder: Option<ContentDecoder>,
stream: S, stream: S,
eof: bool, eof: bool,
fut: Option<CpuFuture<Result<(Option<Bytes>, ContentDecoder), io::Error>>>, fut: Option<CpuFuture<(Option<Bytes>, ContentDecoder), io::Error>>,
} }
impl<S> Decoder<S> impl<S> Decoder<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>>, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
/// Construct a decoder. /// Construct a decoder.
#[inline] #[inline]
@@ -74,41 +71,34 @@ where
impl<S> Stream for Decoder<S> impl<S> Stream for Decoder<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
loop { loop {
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { let (chunk, decoder) = try_ready!(fut.poll());
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
self.decoder = Some(decoder); self.decoder = Some(decoder);
self.fut.take(); self.fut.take();
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
if self.eof { if self.eof {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
match Pin::new(&mut self.stream).poll_next(cx) { match self.stream.poll()? {
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), Async::Ready(Some(chunk)) => {
Poll::Ready(Some(Ok(chunk))) => {
if let Some(mut decoder) = self.decoder.take() { if let Some(mut decoder) = self.decoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
let chunk = decoder.feed_data(chunk)?; let chunk = decoder.feed_data(chunk)?;
self.decoder = Some(decoder); self.decoder = Some(decoder);
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -118,25 +108,21 @@ where
} }
continue; continue;
} else { } else {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
Poll::Ready(None) => { Async::Ready(None) => {
self.eof = true; self.eof = true;
return if let Some(mut decoder) = self.decoder.take() { return if let Some(mut decoder) = self.decoder.take() {
match decoder.feed_eof() { Ok(Async::Ready(decoder.feed_eof()?))
Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
Ok(None) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(err.into()))),
}
} else { } else {
Poll::Ready(None) Ok(Async::Ready(None))
}; };
} }
Poll::Pending => break, Async::NotReady => break,
} }
} }
Poll::Pending Ok(Async::NotReady)
} }
} }

View File

@@ -1,8 +1,5 @@
//! Stream encoder //! Stream encoder
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
@@ -10,6 +7,7 @@ use brotli2::write::BrotliEncoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzEncoder, ZlibEncoder}; use flate2::write::{GzEncoder, ZlibEncoder};
use futures::{Async, Future, Poll};
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING};
@@ -24,7 +22,7 @@ pub struct Encoder<B> {
eof: bool, eof: bool,
body: EncoderBody<B>, body: EncoderBody<B>,
encoder: Option<ContentEncoder>, encoder: Option<ContentEncoder>,
fut: Option<CpuFuture<Result<ContentEncoder, io::Error>>>, fut: Option<CpuFuture<ContentEncoder, io::Error>>,
} }
impl<B: MessageBody> Encoder<B> { impl<B: MessageBody> Encoder<B> {
@@ -96,46 +94,43 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
} }
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
loop { loop {
if self.eof { if self.eof {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let mut encoder = match futures::ready!(Pin::new(fut).poll(cx)) { let mut encoder = futures::try_ready!(fut.poll());
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
self.fut.take(); self.fut.take();
if !chunk.is_empty() { if !chunk.is_empty() {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
let result = match self.body { let result = match self.body {
EncoderBody::Bytes(ref mut b) => { EncoderBody::Bytes(ref mut b) => {
if b.is_empty() { if b.is_empty() {
Poll::Ready(None) Async::Ready(None)
} else { } else {
Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new())))) Async::Ready(Some(std::mem::replace(b, Bytes::new())))
} }
} }
EncoderBody::Stream(ref mut b) => b.poll_next(cx), EncoderBody::Stream(ref mut b) => b.poll_next()?,
EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx), EncoderBody::BoxedStream(ref mut b) => b.poll_next()?,
}; };
match result { match result {
Poll::Ready(Some(Ok(chunk))) => { Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(chunk)) => {
if let Some(mut encoder) = self.encoder.take() { if let Some(mut encoder) = self.encoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
encoder.write(&chunk)?; encoder.write(&chunk)?;
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
if !chunk.is_empty() { if !chunk.is_empty() {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -144,23 +139,22 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
})); }));
} }
} else { } else {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
Poll::Ready(None) => { Async::Ready(None) => {
if let Some(encoder) = self.encoder.take() { if let Some(encoder) = self.encoder.take() {
let chunk = encoder.finish()?; let chunk = encoder.finish()?;
if chunk.is_empty() { if chunk.is_empty() {
return Poll::Ready(None); return Ok(Async::Ready(None));
} else { } else {
self.eof = true; self.eof = true;
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
} }
val => return val,
} }
} }
} }

View File

@@ -6,10 +6,11 @@ use std::str::Utf8Error;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use std::{fmt, io, result}; use std::{fmt, io, result};
pub use actix_threadpool::BlockingError;
use actix_utils::timeout::TimeoutError; use actix_utils::timeout::TimeoutError;
use bytes::BytesMut; use bytes::BytesMut;
use derive_more::{Display, From}; use derive_more::{Display, From};
pub use futures::channel::oneshot::Canceled; use futures::Canceled;
use http::uri::InvalidUri; use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode}; use http::{header, Error as HttpError, StatusCode};
use httparse; use httparse;
@@ -107,7 +108,7 @@ impl fmt::Display for Error {
impl fmt::Debug for Error { impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "{:?}", &self.cause) write!(f, "{:?}", &self.cause)
} }
} }
@@ -181,13 +182,13 @@ impl ResponseError for FormError {}
/// `InternalServerError` for `TimerError` /// `InternalServerError` for `TimerError`
impl ResponseError for TimerError {} impl ResponseError for TimerError {}
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
/// `InternalServerError` for `openssl::ssl::Error` /// `InternalServerError` for `openssl::ssl::Error`
impl ResponseError for open_ssl::ssl::Error {} impl ResponseError for openssl::ssl::Error {}
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
/// `InternalServerError` for `openssl::ssl::HandshakeError` /// `InternalServerError` for `openssl::ssl::HandshakeError`
impl<T: std::fmt::Debug> ResponseError for open_ssl::ssl::HandshakeError<T> {} impl ResponseError for openssl::ssl::HandshakeError<tokio_tcp::TcpStream> {}
/// Return `BAD_REQUEST` for `de::value::Error` /// Return `BAD_REQUEST` for `de::value::Error`
impl ResponseError for DeError { impl ResponseError for DeError {
@@ -196,8 +197,8 @@ impl ResponseError for DeError {
} }
} }
/// `InternalServerError` for `Canceled` /// `InternalServerError` for `BlockingError`
impl ResponseError for Canceled {} impl<E: fmt::Debug> ResponseError for BlockingError<E> {}
/// Return `BAD_REQUEST` for `Utf8Error` /// Return `BAD_REQUEST` for `Utf8Error`
impl ResponseError for Utf8Error { impl ResponseError for Utf8Error {
@@ -235,6 +236,9 @@ impl ResponseError for header::InvalidHeaderValueBytes {
} }
} }
/// `InternalServerError` for `futures::Canceled`
impl ResponseError for Canceled {}
/// A set of errors that can occur during parsing HTTP streams /// A set of errors that can occur during parsing HTTP streams
#[derive(Debug, Display)] #[derive(Debug, Display)]
pub enum ParseError { pub enum ParseError {
@@ -361,12 +365,15 @@ impl From<io::Error> for PayloadError {
} }
} }
impl From<Canceled> for PayloadError { impl From<BlockingError<io::Error>> for PayloadError {
fn from(_: Canceled) -> Self { fn from(err: BlockingError<io::Error>) -> Self {
PayloadError::Io(io::Error::new( match err {
io::ErrorKind::Other, BlockingError::Error(e) => PayloadError::Io(e),
"Operation is canceled", BlockingError::Canceled => PayloadError::Io(io::Error::new(
)) io::ErrorKind::Other,
"Thread pool is gone",
)),
}
} }
} }

View File

@@ -1,12 +1,10 @@
use std::future::Future;
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Decoder; use actix_codec::Decoder;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version};
use httparse; use httparse;
@@ -444,10 +442,9 @@ impl Decoder for PayloadDecoder {
loop { loop {
let mut buf = None; let mut buf = None;
// advances the chunked state // advances the chunked state
*state = match state.step(src, size, &mut buf) { *state = match state.step(src, size, &mut buf)? {
Poll::Pending => return Ok(None), Async::NotReady => return Ok(None),
Poll::Ready(Ok(state)) => state, Async::Ready(state) => state,
Poll::Ready(Err(e)) => return Err(e),
}; };
if *state == ChunkedState::End { if *state == ChunkedState::End {
trace!("End of chunked stream"); trace!("End of chunked stream");
@@ -479,7 +476,7 @@ macro_rules! byte (
$rdr.split_to(1); $rdr.split_to(1);
b b
} else { } else {
return Poll::Pending return Ok(Async::NotReady)
} }
}) })
); );
@@ -490,7 +487,7 @@ impl ChunkedState {
body: &mut BytesMut, body: &mut BytesMut,
size: &mut u64, size: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
use self::ChunkedState::*; use self::ChunkedState::*;
match *self { match *self {
Size => ChunkedState::read_size(body, size), Size => ChunkedState::read_size(body, size),
@@ -502,14 +499,10 @@ impl ChunkedState {
BodyLf => ChunkedState::read_body_lf(body), BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body), EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body), EndLf => ChunkedState::read_end_lf(body),
End => Poll::Ready(Ok(ChunkedState::End)), End => Ok(Async::Ready(ChunkedState::End)),
} }
} }
fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll<ChunkedState, io::Error> {
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16; let radix = 16;
match byte!(rdr) { match byte!(rdr) {
b @ b'0'..=b'9' => { b @ b'0'..=b'9' => {
@@ -524,49 +517,48 @@ impl ChunkedState {
*size *= radix; *size *= radix;
*size += u64::from(b + 10 - b'A'); *size += u64::from(b + 10 - b'A');
} }
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)), b';' => return Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)),
_ => { _ => {
return Poll::Ready(Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size", "Invalid chunk size line: Invalid Size",
))); ));
} }
} }
Poll::Ready(Ok(ChunkedState::Size)) Ok(Async::Ready(ChunkedState::Size))
} }
fn read_size_lws(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws"); trace!("read_size_lws");
match byte!(rdr) { match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come // LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)), b';' => Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space", "Invalid chunk size linear white space",
))), )),
} }
} }
fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_extension(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions
} }
} }
fn read_size_lf( fn read_size_lf(
rdr: &mut BytesMut, rdr: &mut BytesMut,
size: &mut u64, size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)),
b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size LF", "Invalid chunk size LF",
))), )),
} }
} }
@@ -574,12 +566,12 @@ impl ChunkedState {
rdr: &mut BytesMut, rdr: &mut BytesMut,
rem: &mut u64, rem: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
trace!("Chunked read, remaining={:?}", rem); trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64; let len = rdr.len() as u64;
if len == 0 { if len == 0 {
Poll::Ready(Ok(ChunkedState::Body)) Ok(Async::Ready(ChunkedState::Body))
} else { } else {
let slice; let slice;
if *rem > len { if *rem > len {
@@ -591,47 +583,47 @@ impl ChunkedState {
} }
*buf = Some(slice); *buf = Some(slice);
if *rem > 0 { if *rem > 0 {
Poll::Ready(Ok(ChunkedState::Body)) Ok(Async::Ready(ChunkedState::Body))
} else { } else {
Poll::Ready(Ok(ChunkedState::BodyCr)) Ok(Async::Ready(ChunkedState::BodyCr))
} }
} }
} }
fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_body_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body CR", "Invalid chunk body CR",
))), )),
} }
} }
fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_body_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::Size)), b'\n' => Ok(Async::Ready(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body LF", "Invalid chunk body LF",
))), )),
} }
} }
fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_end_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), b'\r' => Ok(Async::Ready(ChunkedState::EndLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end CR", "Invalid chunk end CR",
))), )),
} }
} }
fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_end_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::End)), b'\n' => Ok(Async::Ready(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end LF", "Invalid chunk end LF",
))), )),
} }
} }
} }

View File

@@ -1,17 +1,15 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant; use std::time::Instant;
use std::{fmt, io, io::Write, net}; use std::{fmt, io, net};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_codec::{Decoder, Encoder, Framed, FramedParts};
use actix_server_config::IoStream; use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use futures::{Async, Future, Poll};
use log::{error, trace}; use log::{error, trace};
use tokio_timer::{delay, Delay}; use tokio_timer::Delay;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@@ -263,14 +261,14 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
fn can_read(&self, cx: &mut Context) -> bool { fn can_read(&self) -> bool {
if self if self
.flags .flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{ {
false false
} else if let Some(ref info) = self.payload { } else if let Some(ref info) = self.payload {
info.need_read(cx) == PayloadStatus::Read info.need_read() == PayloadStatus::Read
} else { } else {
true true
} }
@@ -289,7 +287,7 @@ where
/// ///
/// true - got whouldblock /// true - got whouldblock
/// false - didnt get whouldblock /// false - didnt get whouldblock
fn poll_flush(&mut self, cx: &mut Context) -> Result<bool, DispatchError> { fn poll_flush(&mut self) -> Result<bool, DispatchError> {
if self.write_buf.is_empty() { if self.write_buf.is_empty() {
return Ok(false); return Ok(false);
} }
@@ -297,31 +295,31 @@ where
let len = self.write_buf.len(); let len = self.write_buf.len();
let mut written = 0; let mut written = 0;
while written < len { while written < len {
match unsafe { Pin::new_unchecked(&mut self.io) } match self.io.write(&self.write_buf[written..]) {
.poll_write(cx, &self.write_buf[written..]) Ok(0) => {
{
Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new( return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"", "",
))); )));
} }
Poll::Ready(Ok(n)) => { Ok(n) => {
written += n; written += n;
} }
Poll::Pending => { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if written > 0 { if written > 0 {
let _ = self.write_buf.split_to(written); let _ = self.write_buf.split_to(written);
} }
return Ok(true); return Ok(true);
} }
Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), Err(err) => return Err(DispatchError::Io(err)),
} }
} }
if written == self.write_buf.len() { if written > 0 {
unsafe { self.write_buf.set_len(0) } if written == self.write_buf.len() {
} else { unsafe { self.write_buf.set_len(0) }
let _ = self.write_buf.split_to(written); } else {
let _ = self.write_buf.split_to(written);
}
} }
Ok(false) Ok(false)
} }
@@ -352,15 +350,12 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
} }
fn poll_response( fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
&mut self,
cx: &mut Context,
) -> Result<PollResponse, DispatchError> {
loop { loop {
let state = match self.state { let state = match self.state {
State::None => match self.messages.pop_front() { State::None => match self.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => { Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req, cx)?) Some(self.handle_request(req)?)
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
@@ -370,58 +365,54 @@ where
} }
None => None, None => None,
}, },
State::ExpectCall(ref mut fut) => { State::ExpectCall(ref mut fut) => match fut.poll() {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) { Ok(Async::Ready(req)) => {
Poll::Ready(Ok(req)) => { self.send_continue();
self.send_continue(); self.state = State::ServiceCall(self.service.call(req));
self.state = State::ServiceCall(self.service.call(req)); continue;
continue;
}
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
Poll::Pending => None,
} }
} Ok(Async::NotReady) => None,
State::ServiceCall(ref mut fut) => { Err(e) => {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) { let res: Response = e.into().into();
Poll::Ready(Ok(res)) => { let (res, body) = res.replace_body(());
let (res, body) = res.into().replace_body(()); Some(self.send_response(res, body.into_body())?)
self.state = self.send_response(res, body)?;
continue;
}
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
Poll::Pending => None,
} }
} },
State::ServiceCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?;
continue;
}
Ok(Async::NotReady) => None,
Err(e) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
},
State::SendPayload(ref mut stream) => { State::SendPayload(ref mut stream) => {
loop { loop {
if self.write_buf.len() < HW_BUFFER_SIZE { if self.write_buf.len() < HW_BUFFER_SIZE {
match stream.poll_next(cx) { match stream
Poll::Ready(Some(Ok(item))) => { .poll_next()
.map_err(|_| DispatchError::Unknown)?
{
Async::Ready(Some(item)) => {
self.codec.encode( self.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
continue; continue;
} }
Poll::Ready(None) => { Async::Ready(None) => {
self.codec.encode( self.codec.encode(
Message::Chunk(None), Message::Chunk(None),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
self.state = State::None; self.state = State::None;
} }
Poll::Ready(Some(Err(_))) => { Async::NotReady => return Ok(PollResponse::DoNothing),
return Err(DispatchError::Unknown)
}
Poll::Pending => return Ok(PollResponse::DoNothing),
} }
} else { } else {
return Ok(PollResponse::DrainWriteBuf); return Ok(PollResponse::DrainWriteBuf);
@@ -442,7 +433,7 @@ where
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more data and retry // we may read more data and retry
if self.state.is_call() { if self.state.is_call() {
if self.poll_request(cx)? { if self.poll_request()? {
continue; continue;
} }
} else if !self.messages.is_empty() { } else if !self.messages.is_empty() {
@@ -455,21 +446,17 @@ where
Ok(PollResponse::DoNothing) Ok(PollResponse::DoNothing)
} }
fn handle_request( fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
&mut self,
req: Request,
cx: &mut Context,
) -> Result<State<S, B, X>, DispatchError> {
// Handle `EXPECT: 100-Continue` header // Handle `EXPECT: 100-Continue` header
let req = if req.head().expect() { let req = if req.head().expect() {
let mut task = self.expect.call(req); let mut task = self.expect.call(req);
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { match task.poll() {
Poll::Ready(Ok(req)) => { Ok(Async::Ready(req)) => {
self.send_continue(); self.send_continue();
req req
} }
Poll::Pending => return Ok(State::ExpectCall(task)), Ok(Async::NotReady) => return Ok(State::ExpectCall(task)),
Poll::Ready(Err(e)) => { Err(e) => {
let e = e.into(); let e = e.into();
let res: Response = e.into(); let res: Response = e.into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
@@ -482,13 +469,13 @@ where
// Call service // Call service
let mut task = self.service.call(req); let mut task = self.service.call(req);
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { match task.poll() {
Poll::Ready(Ok(res)) => { Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.send_response(res, body) self.send_response(res, body)
} }
Poll::Pending => Ok(State::ServiceCall(task)), Ok(Async::NotReady) => Ok(State::ServiceCall(task)),
Poll::Ready(Err(e)) => { Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.send_response(res, body.into_body()) self.send_response(res, body.into_body())
@@ -497,12 +484,9 @@ where
} }
/// Process one incoming requests /// Process one incoming requests
pub(self) fn poll_request( pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> {
&mut self,
cx: &mut Context,
) -> Result<bool, DispatchError> {
// limit a mount of non processed requests // limit a mount of non processed requests
if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() {
return Ok(false); return Ok(false);
} }
@@ -537,7 +521,7 @@ where
// handle request early // handle request early
if self.state.is_empty() { if self.state.is_empty() {
self.state = self.handle_request(req, cx)?; self.state = self.handle_request(req)?;
} else { } else {
self.messages.push_back(DispatcherMessage::Item(req)); self.messages.push_back(DispatcherMessage::Item(req));
} }
@@ -603,12 +587,12 @@ where
} }
/// keep-alive timer /// keep-alive timer
fn poll_keepalive(&mut self, cx: &mut Context) -> Result<(), DispatchError> { fn poll_keepalive(&mut self) -> Result<(), DispatchError> {
if self.ka_timer.is_none() { if self.ka_timer.is_none() {
// shutdown timeout // shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.codec.config().client_disconnect_timer() { if let Some(interval) = self.codec.config().client_disconnect_timer() {
self.ka_timer = Some(delay(interval)); self.ka_timer = Some(Delay::new(interval));
} else { } else {
self.flags.insert(Flags::READ_DISCONNECT); self.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
@@ -621,8 +605,11 @@ where
} }
} }
match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { match self.ka_timer.as_mut().unwrap().poll().map_err(|e| {
Poll::Ready(()) => { error!("Timer error {:?}", e);
DispatchError::Unknown
})? {
Async::Ready(_) => {
// if we get timeout during shutdown, drop connection // if we get timeout during shutdown, drop connection
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
@@ -637,9 +624,9 @@ where
if let Some(deadline) = if let Some(deadline) =
self.codec.config().client_disconnect_timer() self.codec.config().client_disconnect_timer()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} else { } else {
// no shutdown timeout, drop socket // no shutdown timeout, drop socket
@@ -663,37 +650,23 @@ where
} else if let Some(deadline) = } else if let Some(deadline) =
self.codec.config().keep_alive_expire() self.codec.config().keep_alive_expire()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} }
} else if let Some(mut timer) = self.ka_timer.as_mut() { } else if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(self.ka_expire); timer.reset(self.ka_expire);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} }
Poll::Pending => (), Async::NotReady => (),
} }
Ok(()) Ok(())
} }
} }
impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
@@ -706,28 +679,27 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.as_mut().inner { match self.inner {
DispatcherState::Normal(ref mut inner) => { DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive(cx)?; inner.poll_keepalive()?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
// flush buffer // flush buffer
inner.poll_flush(cx)?; inner.poll_flush()?;
if !inner.write_buf.is_empty() { if !inner.write_buf.is_empty() {
Poll::Pending Ok(Async::NotReady)
} else { } else {
match Pin::new(&mut inner.io).poll_shutdown(cx) { match inner.io.shutdown()? {
Poll::Ready(res) => { Async::Ready(_) => Ok(Async::Ready(())),
Poll::Ready(res.map_err(DispatchError::from)) Async::NotReady => Ok(Async::NotReady),
}
Poll::Pending => Poll::Pending,
} }
} }
} }
@@ -735,12 +707,12 @@ where
// read socket into a buf // read socket into a buf
let should_disconnect = let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
read_available(cx, &mut inner.io, &mut inner.read_buf)? read_available(&mut inner.io, &mut inner.read_buf)?
} else { } else {
None None
}; };
inner.poll_request(cx)?; inner.poll_request()?;
if let Some(true) = should_disconnect { if let Some(true) = should_disconnect {
inner.flags.insert(Flags::READ_DISCONNECT); inner.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = inner.payload.take() { if let Some(mut payload) = inner.payload.take() {
@@ -752,7 +724,7 @@ where
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE); inner.write_buf.reserve(HW_BUFFER_SIZE);
} }
let result = inner.poll_response(cx)?; let result = inner.poll_response()?;
let drain = result == PollResponse::DrainWriteBuf; let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler // switch to upgrade handler
@@ -770,7 +742,7 @@ where
self.inner = DispatcherState::Upgrade( self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)), inner.upgrade.unwrap().call((req, framed)),
); );
return self.poll(cx); return self.poll();
} else { } else {
panic!() panic!()
} }
@@ -779,14 +751,14 @@ where
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX) // so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck // and we have to write again otherwise response can get stuck
if inner.poll_flush(cx)? || !drain { if inner.poll_flush()? || !drain {
break; break;
} }
} }
// client is gone // client is gone
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
return Poll::Ready(Ok(())); return Ok(Async::Ready(()));
} }
let is_empty = inner.state.is_empty(); let is_empty = inner.state.is_empty();
@@ -799,44 +771,38 @@ where
// keep-alive and stream errors // keep-alive and stream errors
if is_empty && inner.write_buf.is_empty() { if is_empty && inner.write_buf.is_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner.error.take() {
Poll::Ready(Err(err)) Err(err)
} }
// disconnect if keep-alive is not enabled // disconnect if keep-alive is not enabled
else if inner.flags.contains(Flags::STARTED) else if inner.flags.contains(Flags::STARTED)
&& !inner.flags.intersects(Flags::KEEPALIVE) && !inner.flags.intersects(Flags::KEEPALIVE)
{ {
inner.flags.insert(Flags::SHUTDOWN); inner.flags.insert(Flags::SHUTDOWN);
self.poll(cx) self.poll()
} }
// disconnect if shutdown // disconnect if shutdown
else if inner.flags.contains(Flags::SHUTDOWN) { else if inner.flags.contains(Flags::SHUTDOWN) {
self.poll(cx) self.poll()
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
DispatcherState::Upgrade(ref mut fut) => { DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| { error!("Upgrade handler error: {}", e);
error!("Upgrade handler error: {}", e); DispatchError::Upgrade
DispatchError::Upgrade }),
})
}
DispatcherState::None => panic!(), DispatcherState::None => panic!(),
} }
} }
} }
fn read_available<T>( fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Result<Option<bool>, io::Error>
where where
T: AsyncRead + Unpin, T: io::Read,
{ {
let mut read_some = false; let mut read_some = false;
loop { loop {
@@ -844,18 +810,19 @@ where
buf.reserve(HW_BUFFER_SIZE); buf.reserve(HW_BUFFER_SIZE);
} }
match read(cx, io, buf) { let read = unsafe { io.read(buf.bytes_mut()) };
Poll::Pending => { match read {
return if read_some { Ok(Some(false)) } else { Ok(None) }; Ok(n) => {
}
Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
return Ok(Some(true)); return Ok(Some(true));
} else { } else {
read_some = true; read_some = true;
unsafe {
buf.advance_mut(n);
}
} }
} }
Poll::Ready(Err(e)) => { Err(e) => {
return if e.kind() == io::ErrorKind::WouldBlock { return if e.kind() == io::ErrorKind::WouldBlock {
if read_some { if read_some {
Ok(Some(false)) Ok(Some(false))
@@ -866,23 +833,12 @@ where
Ok(Some(true)) Ok(Some(true))
} else { } else {
Err(e) Err(e)
} };
} }
} }
} }
} }
fn read<T>(
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Poll<Result<usize, io::Error>>
where
T: AsyncRead + Unpin,
{
Pin::new(io).poll_read_buf(cx, buf)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::IntoService; use actix_service::IntoService;
@@ -896,7 +852,7 @@ mod tests {
#[test] #[test]
fn test_req_parse_err() { fn test_req_parse_err() {
let mut sys = actix_rt::System::new("test"); let mut sys = actix_rt::System::new("test");
let _ = sys.block_on(lazy(|cx| { let _ = sys.block_on(lazy(|| {
let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new( let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
@@ -909,10 +865,7 @@ mod tests {
None, None,
None, None,
); );
match Pin::new(&mut h1).poll(cx) { assert!(h1.poll().is_err());
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),
}
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal(ref inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert!(inner.flags.contains(Flags::READ_DISCONNECT));

View File

@@ -1,24 +1,21 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::request::Request; use crate::request::Request;
pub struct ExpectHandler; pub struct ExpectHandler;
impl ServiceFactory for ExpectHandler { impl NewService for ExpectHandler {
type Config = ServerConfig; type Config = ServerConfig;
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Service = ExpectHandler; type Service = ExpectHandler;
type InitError = Error; type InitError = Error;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &ServerConfig) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
ok(ExpectHandler) ok(ExpectHandler)
@@ -29,10 +26,10 @@ impl Service for ExpectHandler {
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {

View File

@@ -1,14 +1,12 @@
//! Payload stream //! Payload stream
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::rc::{Rc, Weak}; use std::rc::{Rc, Weak};
use std::task::{Context, Poll};
use actix_utils::task::LocalWaker;
use bytes::Bytes; use bytes::Bytes;
use futures::Stream; use futures::task::current as current_task;
use futures::task::Task;
use futures::{Async, Poll, Stream};
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -79,24 +77,15 @@ impl Payload {
pub fn unread_data(&mut self, data: Bytes) { pub fn unread_data(&mut self, data: Bytes) {
self.inner.borrow_mut().unread_data(data); self.inner.borrow_mut().unread_data(data);
} }
#[inline]
pub fn readany(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
}
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( #[inline]
self: Pin<&mut Self>, fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> {
cx: &mut Context, self.inner.borrow_mut().readany()
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
} }
} }
@@ -128,14 +117,19 @@ impl PayloadSender {
} }
#[inline] #[inline]
pub fn need_read(&self, cx: &mut Context) -> PayloadStatus { pub fn need_read(&self) -> PayloadStatus {
// we check need_read only if Payload (other side) is alive, // we check need_read only if Payload (other side) is alive,
// otherwise always return true (consume payload) // otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() { if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read { if shared.borrow().need_read {
PayloadStatus::Read PayloadStatus::Read
} else { } else {
shared.borrow_mut().io_task.register(cx.waker()); #[cfg(not(test))]
{
if shared.borrow_mut().io_task.is_none() {
shared.borrow_mut().io_task = Some(current_task());
}
}
PayloadStatus::Pause PayloadStatus::Pause
} }
} else { } else {
@@ -151,8 +145,8 @@ struct Inner {
err: Option<PayloadError>, err: Option<PayloadError>,
need_read: bool, need_read: bool,
items: VecDeque<Bytes>, items: VecDeque<Bytes>,
task: LocalWaker, task: Option<Task>,
io_task: LocalWaker, io_task: Option<Task>,
} }
impl Inner { impl Inner {
@@ -163,8 +157,8 @@ impl Inner {
err: None, err: None,
items: VecDeque::new(), items: VecDeque::new(),
need_read: true, need_read: true,
task: LocalWaker::new(), task: None,
io_task: LocalWaker::new(), io_task: None,
} }
} }
@@ -184,7 +178,7 @@ impl Inner {
self.items.push_back(data); self.items.push_back(data);
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
task.wake() task.notify()
} }
} }
@@ -193,28 +187,34 @@ impl Inner {
self.len self.len
} }
fn readany( fn readany(&mut self) -> Poll<Option<Bytes>, PayloadError> {
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
if let Some(data) = self.items.pop_front() { if let Some(data) = self.items.pop_front() {
self.len -= data.len(); self.len -= data.len();
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if self.need_read && !self.eof { if self.need_read && self.task.is_none() && !self.eof {
self.task.register(cx.waker()); self.task = Some(current_task());
} }
self.io_task.wake(); if let Some(task) = self.io_task.take() {
Poll::Ready(Some(Ok(data))) task.notify()
}
Ok(Async::Ready(Some(data)))
} else if let Some(err) = self.err.take() { } else if let Some(err) = self.err.take() {
Poll::Ready(Some(Err(err))) Err(err)
} else if self.eof { } else if self.eof {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
self.need_read = true; self.need_read = true;
self.task.register(cx.waker()); #[cfg(not(test))]
self.io_task.wake(); {
Poll::Pending if self.task.is_none() {
self.task = Some(current_task());
}
if let Some(task) = self.io_task.take() {
task.notify()
}
}
Ok(Async::NotReady)
} }
} }
@@ -228,23 +228,27 @@ impl Inner {
mod tests { mod tests {
use super::*; use super::*;
use actix_rt::Runtime; use actix_rt::Runtime;
use futures::future::{poll_fn, ready}; use futures::future::{lazy, result};
#[test] #[test]
fn test_unread_data() { fn test_unread_data() {
Runtime::new().unwrap().block_on(async { Runtime::new()
let (_, mut payload) = Payload::create(false); .unwrap()
.block_on(lazy(|| {
let (_, mut payload) = Payload::create(false);
payload.unread_data(Bytes::from("data")); payload.unread_data(Bytes::from("data"));
assert!(!payload.is_empty()); assert!(!payload.is_empty());
assert_eq!(payload.len(), 4); assert_eq!(payload.len(), 4);
assert_eq!( assert_eq!(
Bytes::from("data"), Async::Ready(Some(Bytes::from("data"))),
poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() payload.poll().ok().unwrap()
); );
ready(()) let res: Result<(), ()> = Ok(());
}); result(res)
}))
.unwrap();
} }
} }

View File

@@ -1,15 +1,12 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use futures::{ready, Stream}; use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@@ -23,7 +20,7 @@ use super::codec::Codec;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
use super::{ExpectHandler, Message, UpgradeHandler}; use super::{ExpectHandler, Message, UpgradeHandler};
/// `ServiceFactory` implementation for HTTP1 transport /// `NewService` implementation for HTTP1 transport
pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> { pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -35,19 +32,19 @@ pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, P, S, B> H1Service<T, P, S, B> impl<T, P, S, B> H1Service<T, P, S, B>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
{ {
/// Create new `HttpService` instance with default config. /// Create new `HttpService` instance with default config.
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self { pub fn new<F: IntoNewService<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H1Service { H1Service {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -56,13 +53,10 @@ where
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub fn with_config<F: IntoServiceFactory<S>>( pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
cfg: ServiceConfig,
service: F,
) -> Self {
H1Service { H1Service {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -73,7 +67,7 @@ where
impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U> impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
@@ -81,7 +75,7 @@ where
{ {
pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U>
where where
X1: ServiceFactory<Request = Request, Response = Request>, X1: NewService<Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
{ {
@@ -97,7 +91,7 @@ where
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1>
where where
U1: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
{ {
@@ -121,18 +115,18 @@ where
} }
} }
impl<T, P, S, B, X, U> ServiceFactory for H1Service<T, P, S, B, X, U> impl<T, P, S, B, X, U> NewService for H1Service<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory< U: NewService<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
@@ -150,7 +144,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H1ServiceResponse { H1ServiceResponse {
fut: self.srv.new_service(cfg), fut: self.srv.new_service(cfg).into_future(),
fut_ex: Some(self.expect.new_service(cfg)), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
@@ -163,24 +157,20 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project]
pub struct H1ServiceResponse<T, P, S, B, X, U> pub struct H1ServiceResponse<T, P, S, B, X, U>
where where
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
@@ -192,57 +182,49 @@ where
impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U> impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
type Output = type Item = H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
Result<H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>; type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
if let Some(fut) = this.fut_ex.as_pin_mut() { .poll()
let expect = ready!(fut .map_err(|e| log::error!("Init http service error: {:?}", e)));
.poll(cx) self.expect = Some(expect);
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; self.fut_ex.take();
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
} }
if let Some(fut) = this.fut_upg.as_pin_mut() { if let Some(ref mut fut) = self.fut_upg {
let upgrade = ready!(fut let upgrade = try_ready!(fut
.poll(cx) .poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; .map_err(|e| log::error!("Init http service error: {:?}", e)));
this = self.as_mut().project(); self.upgrade = Some(upgrade);
*this.upgrade = Some(upgrade); self.fut_ex.take();
this.fut_ex.set(None);
} }
let result = ready!(this let service = try_ready!(self
.fut .fut
.poll(cx) .poll()
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(H1ServiceHandler::new(
Poll::Ready(result.map(|service| { self.cfg.take().unwrap(),
let this = self.as_mut().project(); service,
H1ServiceHandler::new( self.expect.take().unwrap(),
this.cfg.take().unwrap(), self.upgrade.take(),
service, self.on_connect.clone(),
this.expect.take().unwrap(), )))
this.upgrade.take(),
this.on_connect.clone(),
)
}))
} }
} }
@@ -302,10 +284,10 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>; type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
let ready = self let ready = self
.expect .expect
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -315,7 +297,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -325,9 +307,9 @@ where
&& ready; && ready;
if ready { if ready {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
@@ -351,7 +333,7 @@ where
} }
} }
/// `ServiceFactory` implementation for `OneRequestService` service /// `NewService` implementation for `OneRequestService` service
#[derive(Default)] #[derive(Default)]
pub struct OneRequest<T, P> { pub struct OneRequest<T, P> {
config: ServiceConfig, config: ServiceConfig,
@@ -371,7 +353,7 @@ where
} }
} }
impl<T, P> ServiceFactory for OneRequest<T, P> impl<T, P> NewService for OneRequest<T, P>
where where
T: IoStream, T: IoStream,
{ {
@@ -381,7 +363,7 @@ where
type Error = ParseError; type Error = ParseError;
type InitError = (); type InitError = ();
type Service = OneRequestService<T, P>; type Service = OneRequestService<T, P>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &SrvConfig) -> Self::Future { fn new_service(&self, _: &SrvConfig) -> Self::Future {
ok(OneRequestService { ok(OneRequestService {
@@ -407,8 +389,8 @@ where
type Error = ParseError; type Error = ParseError;
type Future = OneRequestServiceResponse<T>; type Future = OneRequestServiceResponse<T>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
@@ -433,19 +415,19 @@ impl<T> Future for OneRequestServiceResponse<T>
where where
T: IoStream, T: IoStream,
{ {
type Output = Result<(Request, Framed<T, Codec>), ParseError>; type Item = (Request, Framed<T, Codec>);
type Error = ParseError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.framed.as_mut().unwrap().next_item(cx) { match self.framed.as_mut().unwrap().poll()? {
Poll::Ready(Some(Ok(req))) => match req { Async::Ready(Some(req)) => match req {
Message::Item(req) => { Message::Item(req) => {
Poll::Ready(Ok((req, self.framed.take().unwrap()))) Ok(Async::Ready((req, self.framed.take().unwrap())))
} }
Message::Chunk(_) => unreachable!("Something is wrong"), Message::Chunk(_) => unreachable!("Something is wrong"),
}, },
Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), Async::Ready(None) => Err(ParseError::Incomplete),
Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)), Async::NotReady => Ok(Async::NotReady),
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,12 +1,10 @@
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::Ready; use futures::future::FutureResult;
use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::h1::Codec; use crate::h1::Codec;
@@ -14,14 +12,14 @@ use crate::request::Request;
pub struct UpgradeHandler<T>(PhantomData<T>); pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> ServiceFactory for UpgradeHandler<T> { impl<T> NewService for UpgradeHandler<T> {
type Config = ServerConfig; type Config = ServerConfig;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Service = UpgradeHandler<T>; type Service = UpgradeHandler<T>;
type InitError = Error; type InitError = Error;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &ServerConfig) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
unimplemented!() unimplemented!()
@@ -32,10 +30,10 @@ impl<T> Service for UpgradeHandler<T> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, _: Self::Request) -> Self::Future { fn call(&mut self, _: Self::Request) -> Self::Future {

View File

@@ -1,9 +1,5 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use futures::Sink; use futures::{Async, Future, Poll, Sink};
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::error::Error; use crate::error::Error;
@@ -11,7 +7,6 @@ use crate::h1::{Codec, Message};
use crate::response::Response; use crate::response::Response;
/// Send http/1 response /// Send http/1 response
#[pin_project::pin_project]
pub struct SendResponse<T, B> { pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
body: Option<ResponseBody<B>>, body: Option<ResponseBody<B>>,
@@ -38,61 +33,60 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
B: MessageBody, B: MessageBody,
{ {
type Output = Result<Framed<T, Codec>, Error>; type Item = Framed<T, Codec>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
let mut body_ready = this.body.is_some(); let mut body_ready = self.body.is_some();
let framed = this.framed.as_mut().unwrap(); let framed = self.framed.as_mut().unwrap();
// send body // send body
if this.res.is_none() && this.body.is_some() { if self.res.is_none() && self.body.is_some() {
while body_ready && this.body.is_some() && !framed.is_write_buf_full() { while body_ready && self.body.is_some() && !framed.is_write_buf_full() {
match this.body.as_mut().unwrap().poll_next(cx)? { match self.body.as_mut().unwrap().poll_next()? {
Poll::Ready(item) => { Async::Ready(item) => {
// body is done // body is done
if item.is_none() { if item.is_none() {
let _ = this.body.take(); let _ = self.body.take();
} }
framed.write(Message::Chunk(item))?; framed.force_send(Message::Chunk(item))?;
} }
Poll::Pending => body_ready = false, Async::NotReady => body_ready = false,
} }
} }
} }
// flush write buffer // flush write buffer
if !framed.is_write_buf_empty() { if !framed.is_write_buf_empty() {
match framed.flush(cx)? { match framed.poll_complete()? {
Poll::Ready(_) => { Async::Ready(_) => {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Poll::Pending; return Ok(Async::NotReady);
} }
} }
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
} }
} }
// send response // send response
if let Some(res) = this.res.take() { if let Some(res) = self.res.take() {
framed.write(res)?; framed.force_send(res)?;
continue; continue;
} }
if this.body.is_some() { if self.body.is_some() {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Poll::Pending; return Ok(Async::NotReady);
} }
} else { } else {
break; break;
} }
} }
Poll::Ready(Ok(this.framed.take().unwrap())) Ok(Async::Ready(self.framed.take().unwrap()))
} }
} }

View File

@@ -1,8 +1,5 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant; use std::time::Instant;
use std::{fmt, mem, net}; use std::{fmt, mem, net};
@@ -11,7 +8,7 @@ use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{ready, Sink, Stream}; use futures::{try_ready, Async, Future, Poll, Sink, Stream};
use h2::server::{Connection, SendResponse}; use h2::server::{Connection, SendResponse};
use h2::{RecvStream, SendStream}; use h2::{RecvStream, SendStream};
use http::header::{ use http::header::{
@@ -35,7 +32,6 @@ use crate::response::Response;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol /// Dispatcher for HTTP/2 protocol
#[pin_project::pin_project]
pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> { pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> {
service: CloneableService<S>, service: CloneableService<S>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
@@ -52,9 +48,9 @@ where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
// S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
pub(crate) fn new( pub(crate) fn new(
service: CloneableService<S>, service: CloneableService<S>,
@@ -97,75 +93,61 @@ impl<T, S, B> Future for Dispatcher<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.get_mut();
loop { loop {
match Pin::new(&mut this.connection).poll_accept(cx) { match self.connection.poll()? {
Poll::Ready(None) => return Poll::Ready(Ok(())), Async::Ready(None) => return Ok(Async::Ready(())),
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), Async::Ready(Some((req, res))) => {
Poll::Ready(Some(Ok((req, res)))) => {
// update keep-alive expire // update keep-alive expire
if this.ka_timer.is_some() { if self.ka_timer.is_some() {
if let Some(expire) = this.config.keep_alive_expire() { if let Some(expire) = self.config.keep_alive_expire() {
this.ka_expire = expire; self.ka_expire = expire;
} }
} }
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
let mut req = Request::with_payload(Payload::< let mut req = Request::with_payload(body.into());
crate::payload::PayloadStream,
>::H2(
crate::h2::Payload::new(body)
));
let head = &mut req.head_mut(); let head = &mut req.head_mut();
head.uri = parts.uri; head.uri = parts.uri;
head.method = parts.method; head.method = parts.method;
head.version = parts.version; head.version = parts.version;
head.headers = parts.headers.into(); head.headers = parts.headers.into();
head.peer_addr = this.peer_addr; head.peer_addr = self.peer_addr;
// set on_connect data // set on_connect data
if let Some(ref on_connect) = this.on_connect { if let Some(ref on_connect) = self.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
tokio_executor::current_thread::spawn(ServiceResponse::< tokio_current_thread::spawn(ServiceResponse::<S::Future, B> {
S::Future,
S::Response,
S::Error,
B,
> {
state: ServiceResponseState::ServiceCall( state: ServiceResponseState::ServiceCall(
this.service.call(req), self.service.call(req),
Some(res), Some(res),
), ),
config: this.config.clone(), config: self.config.clone(),
buffer: None, buffer: None,
_t: PhantomData, })
});
} }
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
} }
} }
} }
} }
#[pin_project::pin_project] struct ServiceResponse<F, B> {
struct ServiceResponse<F, I, E, B> {
state: ServiceResponseState<F, B>, state: ServiceResponseState<F, B>,
config: ServiceConfig, config: ServiceConfig,
buffer: Option<Bytes>, buffer: Option<Bytes>,
_t: PhantomData<(I, E)>,
} }
enum ServiceResponseState<F, B> { enum ServiceResponseState<F, B> {
@@ -173,12 +155,12 @@ enum ServiceResponseState<F, B> {
SendPayload(SendStream<Bytes>, ResponseBody<B>), SendPayload(SendStream<Bytes>, ResponseBody<B>),
} }
impl<F, I, E, B> ServiceResponse<F, I, E, B> impl<F, B> ServiceResponse<F, B>
where where
F: Future<Output = Result<I, E>>, F: Future,
E: Into<Error>, F::Error: Into<Error>,
I: Into<Response<B>>, F::Item: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
fn prepare_response( fn prepare_response(
&self, &self,
@@ -241,121 +223,109 @@ where
} }
} }
impl<F, I, E, B> Future for ServiceResponse<F, I, E, B> impl<F, B> Future for ServiceResponse<F, B>
where where
F: Future<Output = Result<I, E>>, F: Future,
E: Into<Error>, F::Error: Into<Error>,
I: Into<Response<B>>, F::Item: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); match self.state {
match this.state {
ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
match unsafe { Pin::new_unchecked(call) }.poll(cx) { match call.poll() {
Poll::Ready(Ok(res)) => { Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res = self.prepare_response(res.head(), &mut size);
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) { let stream =
Err(e) => { send.send_response(h2_res, size.is_eof()).map_err(|e| {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
return Poll::Ready(()); })?;
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
*this.state = self.state = ServiceResponseState::SendPayload(stream, body);
ServiceResponseState::SendPayload(stream, body); self.poll()
self.poll(cx)
} }
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err(e)) => { Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res = self.prepare_response(res.head(), &mut size);
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) { let stream =
Err(e) => { send.send_response(h2_res, size.is_eof()).map_err(|e| {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
return Poll::Ready(()); })?;
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
*this.state = ServiceResponseState::SendPayload( self.state = ServiceResponseState::SendPayload(
stream, stream,
body.into_body(), body.into_body(),
); );
self.poll(cx) self.poll()
} }
} }
} }
} }
ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
loop { loop {
if let Some(ref mut buffer) = this.buffer { if let Some(ref mut buffer) = self.buffer {
match stream.poll_capacity(cx) { match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? {
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
Poll::Ready(None) => return Poll::Ready(()), Async::Ready(None) => return Ok(Async::Ready(())),
Poll::Ready(Some(Ok(cap))) => { Async::Ready(Some(cap)) => {
let len = buffer.len(); let len = buffer.len();
let bytes = buffer.split_to(std::cmp::min(cap, len)); let bytes = buffer.split_to(std::cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) { if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e); warn!("{:?}", e);
return Poll::Ready(()); return Err(());
} else if !buffer.is_empty() { } else if !buffer.is_empty() {
let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap); stream.reserve_capacity(cap);
} else { } else {
this.buffer.take(); self.buffer.take();
} }
} }
Poll::Ready(Some(Err(e))) => {
warn!("{:?}", e);
return Poll::Ready(());
}
} }
} else { } else {
match body.poll_next(cx) { match body.poll_next() {
Poll::Pending => return Poll::Pending, Ok(Async::NotReady) => {
Poll::Ready(None) => { return Ok(Async::NotReady);
}
Ok(Async::Ready(None)) => {
if let Err(e) = stream.send_data(Bytes::new(), true) { if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e); warn!("{:?}", e);
return Err(());
} else {
return Ok(Async::Ready(()));
} }
return Poll::Ready(());
} }
Poll::Ready(Some(Ok(chunk))) => { Ok(Async::Ready(Some(chunk))) => {
stream.reserve_capacity(std::cmp::min( stream.reserve_capacity(std::cmp::min(
chunk.len(), chunk.len(),
CHUNK_SIZE, CHUNK_SIZE,
)); ));
*this.buffer = Some(chunk); self.buffer = Some(chunk);
} }
Poll::Ready(Some(Err(e))) => { Err(e) => {
error!("Response payload stream error: {:?}", e); error!("Response payload stream error: {:?}", e);
return Poll::Ready(()); return Err(());
} }
} }
} }

View File

@@ -1,11 +1,9 @@
#![allow(dead_code, unused_imports)] #![allow(dead_code, unused_imports)]
use std::fmt; use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures::Stream; use futures::{Async, Poll, Stream};
use h2::RecvStream; use h2::RecvStream;
mod dispatcher; mod dispatcher;
@@ -27,23 +25,22 @@ impl Payload {
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let this = self.get_mut(); match self.pl.poll() {
Ok(Async::Ready(Some(chunk))) => {
match Pin::new(&mut this.pl).poll_data(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let len = chunk.len(); let len = chunk.len();
if let Err(err) = this.pl.release_capacity().release_capacity(len) { if let Err(err) = self.pl.release_capacity().release_capacity(len) {
Poll::Ready(Some(Err(err.into()))) Err(err.into())
} else { } else {
Poll::Ready(Some(Ok(chunk))) Ok(Async::Ready(Some(chunk)))
} }
} }
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(None) => Poll::Ready(None), Err(err) => Err(err.into()),
} }
} }
} }

View File

@@ -1,16 +1,13 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, net, rc}; use std::{io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use futures::{ready, Stream}; use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use h2::server::{self, Connection, Handshake}; use h2::server::{self, Connection, Handshake};
use h2::RecvStream; use h2::RecvStream;
use log::error; use log::error;
@@ -26,7 +23,7 @@ use crate::response::Response;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
/// `ServiceFactory` implementation for HTTP2 transport /// `NewService` implementation for HTTP2 transport
pub struct H2Service<T, P, S, B> { pub struct H2Service<T, P, S, B> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -36,33 +33,30 @@ pub struct H2Service<T, P, S, B> {
impl<T, P, S, B> H2Service<T, P, S, B> impl<T, P, S, B> H2Service<T, P, S, B>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self { pub fn new<F: IntoNewService<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
srv: service.into_factory(), srv: service.into_new_service(),
_t: PhantomData, _t: PhantomData,
} }
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub fn with_config<F: IntoServiceFactory<S>>( pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
cfg: ServiceConfig,
service: F,
) -> Self {
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
srv: service.into_factory(), srv: service.into_new_service(),
_t: PhantomData, _t: PhantomData,
} }
} }
@@ -77,12 +71,12 @@ where
} }
} }
impl<T, P, S, B> ServiceFactory for H2Service<T, P, S, B> impl<T, P, S, B> NewService for H2Service<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
@@ -96,7 +90,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H2ServiceResponse { H2ServiceResponse {
fut: self.srv.new_service(cfg), fut: self.srv.new_service(cfg).into_future(),
cfg: Some(self.cfg.clone()), cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
_t: PhantomData, _t: PhantomData,
@@ -105,10 +99,8 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project] pub struct H2ServiceResponse<T, P, S: NewService, B> {
pub struct H2ServiceResponse<T, P, S: ServiceFactory, B> { fut: <S::Future as IntoFuture>::Future,
#[pin]
fut: S::Future,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, P, B)>, _t: PhantomData<(T, P, B)>,
@@ -117,25 +109,22 @@ pub struct H2ServiceResponse<T, P, S: ServiceFactory, B> {
impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B> impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Output = Result<H2ServiceHandler<T, P, S::Service, B>, S::InitError>; type Item = H2ServiceHandler<T, P, S::Service, B>;
type Error = S::InitError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.as_mut().project(); let service = try_ready!(self.fut.poll());
Ok(Async::Ready(H2ServiceHandler::new(
Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { self.cfg.take().unwrap(),
let this = self.as_mut().project(); self.on_connect.clone(),
H2ServiceHandler::new( service,
this.cfg.take().unwrap(), )))
this.on_connect.clone(),
service,
)
}))
} }
} }
@@ -150,9 +139,9 @@ pub struct H2ServiceHandler<T, P, S, B> {
impl<T, P, S, B> H2ServiceHandler<T, P, S, B> impl<T, P, S, B> H2ServiceHandler<T, P, S, B>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
fn new( fn new(
@@ -173,9 +162,9 @@ impl<T, P, S, B> Service for H2ServiceHandler<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Request = Io<T, P>; type Request = Io<T, P>;
@@ -183,8 +172,8 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = H2ServiceHandlerResponse<T, S, B>; type Future = H2ServiceHandlerResponse<T, S, B>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.srv.poll_ready(cx).map_err(|e| { self.srv.poll_ready().map_err(|e| {
let e = e.into(); let e = e.into();
error!("Service readiness error: {:?}", e); error!("Service readiness error: {:?}", e);
DispatchError::Service(e) DispatchError::Service(e)
@@ -230,9 +219,9 @@ pub struct H2ServiceHandlerResponse<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
state: State<T, S, B>, state: State<T, S, B>,
@@ -242,24 +231,25 @@ impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.state { match self.state {
State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), State::Incoming(ref mut disp) => disp.poll(),
State::Handshake( State::Handshake(
ref mut srv, ref mut srv,
ref mut config, ref mut config,
ref peer_addr, ref peer_addr,
ref mut on_connect, ref mut on_connect,
ref mut handshake, ref mut handshake,
) => match Pin::new(handshake).poll(cx) { ) => match handshake.poll() {
Poll::Ready(Ok(conn)) => { Ok(Async::Ready(conn)) => {
self.state = State::Incoming(Dispatcher::new( self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(), srv.take().unwrap(),
conn, conn,
@@ -268,13 +258,13 @@ where
None, None,
*peer_addr, *peer_addr,
)); ));
self.poll(cx) self.poll()
} }
Poll::Ready(Err(err)) => { Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into())) Err(err.into())
} }
Poll::Pending => Poll::Pending,
}, },
} }
} }

View File

@@ -4,8 +4,7 @@
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::new_without_default, clippy::new_without_default,
clippy::borrow_interior_mutable_const, clippy::borrow_interior_mutable_const,
clippy::write_with_newline, clippy::write_with_newline
unused_imports
)] )]
#[macro_use] #[macro_use]

View File

@@ -388,12 +388,6 @@ impl BoxedResponseHead {
pub fn new(status: StatusCode) -> Self { pub fn new(status: StatusCode) -> Self {
RESPONSE_POOL.with(|p| p.get_message(status)) RESPONSE_POOL.with(|p| p.get_message(status))
} }
pub(crate) fn take(&mut self) -> Self {
BoxedResponseHead {
head: self.head.take(),
}
}
} }
impl std::ops::Deref for BoxedResponseHead { impl std::ops::Deref for BoxedResponseHead {
@@ -412,9 +406,7 @@ impl std::ops::DerefMut for BoxedResponseHead {
impl Drop for BoxedResponseHead { impl Drop for BoxedResponseHead {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(head) = self.head.take() { RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap()))
RESPONSE_POOL.with(move |p| p.release(head))
}
} }
} }

View File

@@ -1,15 +1,11 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures::Stream; use futures::{Async, Poll, Stream};
use h2::RecvStream; use h2::RecvStream;
use crate::error::PayloadError; use crate::error::PayloadError;
/// Type represent boxed payload /// Type represent boxed payload
pub type PayloadStream = Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>>>>; pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
/// Type represent streaming payload /// Type represent streaming payload
pub enum Payload<S = PayloadStream> { pub enum Payload<S = PayloadStream> {
@@ -52,17 +48,18 @@ impl<S> Payload<S> {
impl<S> Stream for Payload<S> impl<S> Stream for Payload<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
#[inline] #[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.get_mut() { match self {
Payload::None => Poll::Ready(None), Payload::None => Ok(Async::Ready(None)),
Payload::H1(ref mut pl) => pl.readany(cx), Payload::H1(ref mut pl) => pl.poll(),
Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx), Payload::H2(ref mut pl) => pl.poll(),
Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx), Payload::Stream(ref mut pl) => pl.poll(),
} }
} }
} }

View File

@@ -80,11 +80,6 @@ impl<P> Request<P> {
) )
} }
/// Get request's payload
pub fn payload(&mut self) -> &mut Payload<P> {
&mut self.payload
}
/// Get request's payload /// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> { pub fn take_payload(&mut self) -> Payload<P> {
std::mem::replace(&mut self.payload, Payload::None) std::mem::replace(&mut self.payload, Payload::None)
@@ -204,6 +199,7 @@ mod tests {
assert_eq!(req.uri().query(), Some("q=1")); assert_eq!(req.uri().query(), Some("q=1"));
let s = format!("{:?}", req); let s = format!("{:?}", req);
println!("T: {:?}", s);
assert!(s.contains("Request HTTP/1.1 GET:/index.html")); assert!(s.contains("Request HTTP/1.1 GET:/index.html"));
} }
} }

View File

@@ -1,14 +1,11 @@
//! Http response //! Http response
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::future::Future;
use std::io::Write; use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, str}; use std::{fmt, str};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult, IntoFuture};
use futures::stream::Stream; use futures::Stream;
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
@@ -283,15 +280,13 @@ impl<B: MessageBody> fmt::Debug for Response<B> {
} }
} }
impl Future for Response { impl IntoFuture for Response {
type Output = Result<Response, Error>; type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> { fn into_future(self) -> Self::Future {
Poll::Ready(Ok(Response { ok(self)
head: self.head.take(),
body: self.body.take_body(),
error: self.error.take(),
}))
} }
} }
@@ -640,7 +635,7 @@ impl ResponseBuilder {
/// `ResponseBuilder` can not be used after this call. /// `ResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Response pub fn streaming<S, E>(&mut self, stream: S) -> Response
where where
S: Stream<Item = Result<Bytes, E>> + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
self.body(Body::from_message(BodyStream::new(stream))) self.body(Body::from_message(BodyStream::new(stream)))
@@ -762,11 +757,13 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder {
} }
} }
impl Future for ResponseBuilder { impl IntoFuture for ResponseBuilder {
type Output = Result<Response, Error>; type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> { fn into_future(mut self) -> Self::Future {
Poll::Ready(Ok(self.finish())) ok(self.finish())
} }
} }

View File

@@ -1,17 +1,14 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, net, rc}; use std::{fmt, io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{ use actix_server_config::{
Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig, Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig,
}; };
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{ready, Future}; use futures::{try_ready, Async, Future, IntoFuture, Poll};
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder; use crate::builder::HttpServiceBuilder;
@@ -23,7 +20,7 @@ use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{h1, h2::Dispatcher}; use crate::{h1, h2::Dispatcher};
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation /// `NewService` HTTP1.1/HTTP2 transport implementation
pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> { pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -35,10 +32,10 @@ pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler
impl<T, S, B> HttpService<T, (), S, B> impl<T, S, B> HttpService<T, (), S, B>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
@@ -50,20 +47,20 @@ where
impl<T, P, S, B> HttpService<T, P, S, B> impl<T, P, S, B> HttpService<T, P, S, B>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self { pub fn new<F: IntoNewService<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
HttpService { HttpService {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -72,13 +69,13 @@ where
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>( pub(crate) fn with_config<F: IntoNewService<S>>(
cfg: ServiceConfig, cfg: ServiceConfig,
service: F, service: F,
) -> Self { ) -> Self {
HttpService { HttpService {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -89,11 +86,10 @@ where
impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U> impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U>
where where
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static,
B: MessageBody, B: MessageBody,
{ {
/// Provide service for `EXPECT: 100-Continue` support. /// Provide service for `EXPECT: 100-Continue` support.
@@ -103,10 +99,9 @@ where
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U>
where where
X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>, X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
expect, expect,
@@ -124,14 +119,13 @@ where
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1>
where where
U1: ServiceFactory< U1: NewService<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
upgrade, upgrade,
@@ -153,27 +147,25 @@ where
} }
} }
impl<T, P, S, B, X, U> ServiceFactory for HttpService<T, P, S, B, X, U> impl<T, P, S, B, X, U> NewService for HttpService<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Config = SrvConfig, Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<
U: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Config = SrvConfig; type Config = SrvConfig;
type Request = ServerIo<T, P>; type Request = ServerIo<T, P>;
@@ -185,7 +177,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
HttpServiceResponse { HttpServiceResponse {
fut: self.srv.new_service(cfg), fut: self.srv.new_service(cfg).into_future(),
fut_ex: Some(self.expect.new_service(cfg)), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
@@ -198,20 +190,9 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project] pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewService> {
pub struct HttpServiceResponse<
T,
P,
S: ServiceFactory,
B,
X: ServiceFactory,
U: ServiceFactory,
> {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
@@ -223,59 +204,50 @@ pub struct HttpServiceResponse<
impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U> impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U: ServiceFactory<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Output = type Item = HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
Result<HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>; type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
if let Some(fut) = this.fut_ex.as_pin_mut() { .poll()
let expect = ready!(fut .map_err(|e| log::error!("Init http service error: {:?}", e)));
.poll(cx) self.expect = Some(expect);
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; self.fut_ex.take();
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
} }
if let Some(fut) = this.fut_upg.as_pin_mut() { if let Some(ref mut fut) = self.fut_upg {
let upgrade = ready!(fut let upgrade = try_ready!(fut
.poll(cx) .poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; .map_err(|e| log::error!("Init http service error: {:?}", e)));
this = self.as_mut().project(); self.upgrade = Some(upgrade);
*this.upgrade = Some(upgrade); self.fut_ex.take();
this.fut_ex.set(None);
} }
let result = ready!(this let service = try_ready!(self
.fut .fut
.poll(cx) .poll()
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Poll::Ready(result.map(|service| { Ok(Async::Ready(HttpServiceHandler::new(
let this = self.as_mut().project(); self.cfg.take().unwrap(),
HttpServiceHandler::new( service,
this.cfg.take().unwrap(), self.expect.take().unwrap(),
service, self.upgrade.take(),
this.expect.take().unwrap(), self.on_connect.clone(),
this.upgrade.take(), )))
this.on_connect.clone(),
)
}))
} }
} }
@@ -292,9 +264,9 @@ pub struct HttpServiceHandler<T, P, S, B, X, U> {
impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U> impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@@ -323,9 +295,9 @@ impl<T, P, S, B, X, U> Service for HttpServiceHandler<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@@ -337,10 +309,10 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = HttpServiceHandlerResponse<T, S, B, X, U>; type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
let ready = self let ready = self
.expect .expect
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -350,7 +322,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -360,9 +332,9 @@ where
&& ready; && ready;
if ready { if ready {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
@@ -417,7 +389,6 @@ where
} }
} }
#[pin_project]
enum State<T, S, B, X, U> enum State<T, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
@@ -430,8 +401,8 @@ where
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
H1(#[pin] h1::Dispatcher<T, S, B, X, U>), H1(h1::Dispatcher<T, S, B, X, U>),
H2(#[pin] Dispatcher<Io<T>, S, B>), H2(Dispatcher<Io<T>, S, B>),
Unknown( Unknown(
Option<( Option<(
T, T,
@@ -454,21 +425,19 @@ where
), ),
} }
#[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U> pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
#[pin]
state: State<T, S, B, X, U>, state: State<T, S, B, X, U>,
} }
@@ -478,51 +447,30 @@ impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.project().state.poll(cx) match self.state {
} State::H1(ref mut disp) => disp.poll(),
} State::H2(ref mut disp) => disp.poll(),
impl<T, S, B, X, U> State<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[project]
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), DispatchError>> {
#[project]
match self.as_mut().project() {
State::H1(disp) => disp.poll(cx),
State::H2(disp) => disp.poll(cx),
State::Unknown(ref mut data) => { State::Unknown(ref mut data) => {
if let Some(ref mut item) = data { if let Some(ref mut item) = data {
loop { loop {
// Safety - we only write to the returned slice. // Safety - we only write to the returned slice.
let b = unsafe { item.1.bytes_mut() }; let b = unsafe { item.1.bytes_mut() };
let n = ready!(Pin::new(&mut item.0).poll_read(cx, b))?; let n = try_ready!(item.0.poll_read(b));
if n == 0 { if n == 0 {
return Poll::Ready(Ok(())); return Ok(Async::Ready(()));
} }
// Safety - we know that 'n' bytes have // Safety - we know that 'n' bytes have
// been initialized via the contract of // been initialized via the contract of
@@ -543,15 +491,15 @@ where
inner: io, inner: io,
unread: Some(buf), unread: Some(buf),
}; };
self.set(State::Handshake(Some(( self.state = State::Handshake(Some((
server::handshake(io), server::handshake(io),
cfg, cfg,
srv, srv,
peer_addr, peer_addr,
on_connect, on_connect,
)))); )));
} else { } else {
self.set(State::H1(h1::Dispatcher::with_timeout( self.state = State::H1(h1::Dispatcher::with_timeout(
io, io,
h1::Codec::new(cfg.clone()), h1::Codec::new(cfg.clone()),
cfg, cfg,
@@ -561,38 +509,36 @@ where
expect, expect,
upgrade, upgrade,
on_connect, on_connect,
))) ))
} }
self.poll(cx) self.poll()
} }
State::Handshake(ref mut data) => { State::Handshake(ref mut data) => {
let conn = if let Some(ref mut item) = data { let conn = if let Some(ref mut item) = data {
match Pin::new(&mut item.0).poll(cx) { match item.0.poll() {
Poll::Ready(Ok(conn)) => conn, Ok(Async::Ready(conn)) => conn,
Poll::Ready(Err(err)) => { Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
return Poll::Ready(Err(err.into())); return Err(err.into());
} }
Poll::Pending => return Poll::Pending,
} }
} else { } else {
panic!() panic!()
}; };
let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap();
self.set(State::H2(Dispatcher::new( self.state = State::H2(Dispatcher::new(
srv, conn, on_connect, cfg, None, peer_addr, srv, conn, on_connect, cfg, None, peer_addr,
))); ));
self.poll(cx) self.poll()
} }
} }
} }
} }
/// Wrapper for `AsyncRead + AsyncWrite` types /// Wrapper for `AsyncRead + AsyncWrite` types
#[pin_project::pin_project]
struct Io<T> { struct Io<T> {
unread: Option<BytesMut>, unread: Option<BytesMut>,
#[pin]
inner: T, inner: T,
} }
@@ -622,65 +568,21 @@ impl<T: io::Write> io::Write for Io<T> {
} }
impl<T: AsyncRead> AsyncRead for Io<T> { impl<T: AsyncRead> AsyncRead for Io<T> {
// unsafe fn initializer(&self) -> io::Initializer {
// self.get_mut().inner.initializer()
// }
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf) self.inner.prepare_uninitialized_buffer(buf)
} }
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
if let Some(mut bytes) = this.unread.take() {
let size = std::cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
*this.unread = Some(bytes);
}
Poll::Ready(Ok(size))
} else {
this.inner.poll_read(cx, buf)
}
}
// fn poll_read_vectored(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// bufs: &mut [io::IoSliceMut<'_>],
// ) -> Poll<io::Result<usize>> {
// self.get_mut().inner.poll_read_vectored(cx, bufs)
// }
} }
impl<T: AsyncWrite> tokio_io::AsyncWrite for Io<T> { impl<T: AsyncWrite> AsyncWrite for Io<T> {
fn poll_write( fn shutdown(&mut self) -> Poll<(), io::Error> {
self: Pin<&mut Self>, self.inner.shutdown()
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
} }
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.inner.write_buf(buf)
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
} }
} }
impl<T: IoStream> actix_server_config::IoStream for Io<T> { impl<T: IoStream> IoStream for Io<T> {
#[inline] #[inline]
fn peer_addr(&self) -> Option<net::SocketAddr> { fn peer_addr(&self) -> Option<net::SocketAddr> {
self.inner.peer_addr() self.inner.peer_addr()

View File

@@ -1,13 +1,12 @@
//! Test Various helpers for Actix applications to use during testing. //! Test Various helpers for Actix applications to use during testing.
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{self, Read, Write}; use std::io;
use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_server_config::IoStream; use actix_server_config::IoStream;
use bytes::{Buf, Bytes, BytesMut}; use bytes::{Buf, Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use http::{HttpTryFrom, Method, Uri, Version}; use http::{HttpTryFrom, Method, Uri, Version};
use percent_encoding::percent_encode; use percent_encoding::percent_encode;
@@ -245,31 +244,14 @@ impl io::Write for TestBuffer {
} }
} }
impl AsyncRead for TestBuffer { impl AsyncRead for TestBuffer {}
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().read(buf))
}
}
impl AsyncWrite for TestBuffer { impl AsyncWrite for TestBuffer {
fn poll_write( fn shutdown(&mut self) -> Poll<(), io::Error> {
self: Pin<&mut Self>, Ok(Async::Ready(()))
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().write(buf))
} }
fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { Ok(Async::NotReady)
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
} }
} }

View File

@@ -1,10 +1,7 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use actix_utils::framed::{FramedTransport, FramedTransportError}; use actix_utils::framed::{FramedTransport, FramedTransportError};
use futures::{Future, Poll};
use super::{Codec, Frame, Message}; use super::{Codec, Frame, Message};
@@ -43,9 +40,10 @@ where
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
type Output = Result<(), FramedTransportError<S::Error, Codec>>; type Item = ();
type Error = FramedTransportError<S::Error, Codec>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Pin::new(&mut self.inner).poll(cx) self.inner.poll()
} }
} }

View File

@@ -1,9 +1,9 @@
use actix_service::ServiceFactory; use actix_service::NewService;
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, ok}; use futures::future::{self, ok};
use actix_http::{http, HttpService, Request, Response}; use actix_http::{http, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
@@ -29,63 +29,55 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[test] #[test]
fn test_h1_v2() { fn test_h1_v2() {
block_on(async { env_logger::init();
let srv = TestServer::start(move || { let mut srv = TestServer::new(move || {
HttpService::build() HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) });
}); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success());
let response = srv.get("/").send().await.unwrap(); let request = srv.get("/").header("x-test", "111").send();
assert!(response.status().is_success()); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success());
let request = srv.get("/").header("x-test", "111").send(); // read response
let mut response = request.await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert!(response.status().is_success()); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// read response let response = srv.block_on(srv.post("/").send()).unwrap();
let bytes = response.body().await.unwrap(); assert!(response.status().is_success());
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let mut response = srv.post("/").send().await.unwrap(); // read response
assert!(response.status().is_success()); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// read response
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
} }
#[test] #[test]
fn test_connection_close() { fn test_connection_close() {
block_on(async { let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || { HttpService::build()
HttpService::build() .finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.finish(|_| ok::<_, ()>(Response::Ok().body(STR))) .map(|_| ())
.map(|_| ()) });
}); let response = srv.block_on(srv.get("/").force_close().send()).unwrap();
assert!(response.status().is_success());
let response = srv.get("/").force_close().send().await.unwrap();
assert!(response.status().is_success());
})
} }
#[test] #[test]
fn test_with_query_parameter() { fn test_with_query_parameter() {
block_on(async { let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || { HttpService::build()
HttpService::build() .finish(|req: Request| {
.finish(|req: Request| { if req.uri().query().unwrap().contains("qp=") {
if req.uri().query().unwrap().contains("qp=") { ok::<_, ()>(Response::Ok().finish())
ok::<_, ()>(Response::Ok().finish()) } else {
} else { ok::<_, ()>(Response::BadRequest().finish())
ok::<_, ()>(Response::BadRequest().finish()) }
} })
}) .map(|_| ())
.map(|_| ()) });
});
let request = srv.request(http::Method::GET, srv.url("/?qp=5")); let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send();
let response = request.send().await.unwrap(); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }

View File

@@ -1,545 +0,0 @@
#![cfg(feature = "openssl")]
use std::io;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{err, ok, ready};
use futures::stream::{once, Stream, StreamExt};
use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
let body = stream
.map(|res| match res {
Ok(chunk) => chunk,
Err(_) => panic!(),
})
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
ready(body)
})
.await;
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}
#[test]
fn test_h2_on_connect() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -1,474 +0,0 @@
#![cfg(feature = "rustls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{self, err, ok};
use futures::stream::{once, Stream, StreamExt};
use rust_tls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{self, BufReader};
async fn load_body<S>(mut stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
let mut body = BytesMut::new();
while let Some(item) = stream.next().await {
body.extend_from_slice(&item?)
}
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body1() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok()
.content_length(STR.len() as u64)
.body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(
STR.len() as u64,
body,
)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(
http::header::CONTENT_TYPE,
broken_header,
)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}

View File

@@ -0,0 +1,462 @@
#![cfg(feature = "rust-tls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok, Future};
use futures::stream::{once, Stream};
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{BufReader, Result};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,480 @@
#![cfg(feature = "ssl")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{ok, Future};
use futures::stream::{once, Stream};
use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::io::Result;
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[test]
fn test_h2_on_connect() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
}

View File

@@ -1,27 +1,26 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use actix_utils::framed::FramedTransport; use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::future; use futures::future::{self, ok};
use futures::{SinkExt, StreamExt}; use futures::{Future, Sink, Stream};
async fn ws_service<T: AsyncRead + AsyncWrite + Unpin>( fn ws_service<T: AsyncRead + AsyncWrite>(
(req, mut framed): (Request, Framed<T, h1::Codec>), (req, framed): (Request, Framed<T, h1::Codec>),
) -> Result<(), Error> { ) -> impl Future<Item = (), Error = Error> {
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.await
.unwrap();
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.await
.map_err(|_| panic!()) .map_err(|_| panic!())
.and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!())
})
} }
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> { fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
@@ -31,56 +30,47 @@ async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
Ok(msg) ok(msg)
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
block_on(async { let mut srv = TestServer::new(|| {
let mut srv = TestServer::start(|| { HttpService::build()
HttpService::build() .upgrade(ws_service)
.upgrade(actix_service::service_fn(ws_service)) .finish(|_| future::ok::<_, ()>(Response::NotFound()))
.finish(|_| future::ok::<_, ()>(Response::NotFound())) });
});
// client service // client service
let mut framed = srv.ws().await.unwrap(); let framed = srv.ws().unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Binary("text".into())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await; assert_eq!(
assert_eq!( item,
item.unwrap().unwrap(), Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) );
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Ping("text".into())))
assert_eq!( .unwrap();
item.unwrap().unwrap(), let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
ws::Frame::Pong("text".to_string().into()) assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
);
framed let framed = srv
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.await .unwrap();
.unwrap();
let (item, _framed) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
})
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-identity" name = "actix-identity"
version = "0.2.0-alpha.1" version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Identity service for actix web framework." description = "Identity service for actix web framework."
readme = "README.md" readme = "README.md"
@@ -17,14 +17,14 @@ name = "actix_identity"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "2.0.0-alpha.1", default-features = false, features = ["secure-cookies"] } actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] }
actix-service = "1.0.0-alpha.1" actix-service = "0.4.0"
futures = "0.3.1" futures = "0.1.25"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.1.42" time = "0.1.42"
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"
actix-http = "0.3.0-alpha.1" actix-http = "0.2.3"
bytes = "0.4" bytes = "0.4"

View File

@@ -16,7 +16,7 @@
//! use actix_web::*; //! use actix_web::*;
//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService};
//! //!
//! async fn index(id: Identity) -> String { //! fn index(id: Identity) -> String {
//! // access request identity //! // access request identity
//! if let Some(id) = id.identity() { //! if let Some(id) = id.identity() {
//! format!("Welcome! {}", id) //! format!("Welcome! {}", id)
@@ -25,12 +25,12 @@
//! } //! }
//! } //! }
//! //!
//! async fn login(id: Identity) -> HttpResponse { //! fn login(id: Identity) -> HttpResponse {
//! id.remember("User1".to_owned()); // <- remember identity //! id.remember("User1".to_owned()); // <- remember identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
//! //!
//! async fn logout(id: Identity) -> HttpResponse { //! fn logout(id: Identity) -> HttpResponse {
//! id.forget(); // <- remove identity //! id.forget(); // <- remove identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
@@ -47,13 +47,12 @@
//! } //! }
//! ``` //! ```
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::SystemTime; use std::time::SystemTime;
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, FutureResult};
use futures::{Future, IntoFuture, Poll};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::Duration; use time::Duration;
@@ -166,21 +165,21 @@ where
impl FromRequest for Identity { impl FromRequest for Identity {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Identity, Error>>; type Future = Result<Identity, Error>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(Identity(req.clone())) Ok(Identity(req.clone()))
} }
} }
/// Identity policy definition. /// Identity policy definition.
pub trait IdentityPolicy: Sized + 'static { pub trait IdentityPolicy: Sized + 'static {
/// The return type of the middleware /// The return type of the middleware
type Future: Future<Output = Result<Option<String>, Error>>; type Future: IntoFuture<Item = Option<String>, Error = Error>;
/// The return type of the middleware /// The return type of the middleware
type ResponseFuture: Future<Output = Result<(), Error>>; type ResponseFuture: IntoFuture<Item = (), Error = Error>;
/// Parse the session from request and load data from a service identity. /// Parse the session from request and load data from a service identity.
fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; fn from_request(&self, request: &mut ServiceRequest) -> Self::Future;
@@ -235,7 +234,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = IdentityServiceMiddleware<S, T>; type Transform = IdentityServiceMiddleware<S, T>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(IdentityServiceMiddleware { ok(IdentityServiceMiddleware {
@@ -262,39 +261,46 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.borrow_mut().poll_ready(cx) self.service.borrow_mut().poll_ready()
} }
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&mut self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = self.service.clone();
let backend = self.backend.clone(); let backend = self.backend.clone();
let fut = self.backend.from_request(&mut req);
async move { Box::new(
match fut.await { self.backend.from_request(&mut req).into_future().then(
Ok(id) => { move |res| match res {
req.extensions_mut() Ok(id) => {
.insert(IdentityItem { id, changed: false }); req.extensions_mut()
.insert(IdentityItem { id, changed: false });
let mut res = srv.borrow_mut().call(req).await?; Either::A(srv.borrow_mut().call(req).and_then(move |mut res| {
let id = res.request().extensions_mut().remove::<IdentityItem>(); let id =
res.request().extensions_mut().remove::<IdentityItem>();
if let Some(id) = id { if let Some(id) = id {
match backend.to_response(id.id, id.changed, &mut res).await { Either::A(
Ok(_) => Ok(res), backend
Err(e) => Ok(res.error_response(e)), .to_response(id.id, id.changed, &mut res)
} .into_future()
} else { .then(move |t| match t {
Ok(res) Ok(_) => Ok(res),
Err(e) => Ok(res.error_response(e)),
}),
)
} else {
Either::B(ok(res))
}
}))
} }
} Err(err) => Either::B(ok(req.error_response(err))),
Err(err) => Ok(req.error_response(err)), },
} ),
} )
.boxed_local()
} }
} }
@@ -541,11 +547,11 @@ impl CookieIdentityPolicy {
} }
impl IdentityPolicy for CookieIdentityPolicy { impl IdentityPolicy for CookieIdentityPolicy {
type Future = Ready<Result<Option<String>, Error>>; type Future = Result<Option<String>, Error>;
type ResponseFuture = Ready<Result<(), Error>>; type ResponseFuture = Result<(), Error>;
fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { fn from_request(&self, req: &mut ServiceRequest) -> Self::Future {
ok(self.0.load(req).map( Ok(self.0.load(req).map(
|CookieValue { |CookieValue {
identity, identity,
login_timestamp, login_timestamp,
@@ -597,7 +603,7 @@ impl IdentityPolicy for CookieIdentityPolicy {
} else { } else {
Ok(()) Ok(())
}; };
ok(()) Ok(())
} }
} }
@@ -607,7 +613,7 @@ mod tests {
use super::*; use super::*;
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use actix_web::test::{self, block_on, TestRequest}; use actix_web::test::{self, TestRequest};
use actix_web::{web, App, Error, HttpResponse}; use actix_web::{web, App, Error, HttpResponse};
const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; const COOKIE_KEY_MASTER: [u8; 32] = [0; 32];
@@ -616,138 +622,115 @@ mod tests {
#[test] #[test]
fn test_identity() { fn test_identity() {
block_on(async { let mut srv = test::init_service(
let mut srv = test::init_service( App::new()
App::new() .wrap(IdentityService::new(
.wrap(IdentityService::new( CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .domain("www.rust-lang.org")
.domain("www.rust-lang.org") .name(COOKIE_NAME)
.name(COOKIE_NAME) .path("/")
.path("/") .secure(true),
.secure(true), ))
)) .service(web::resource("/index").to(|id: Identity| {
.service(web::resource("/index").to(|id: Identity| { if id.identity().is_some() {
if id.identity().is_some() { HttpResponse::Created()
HttpResponse::Created() } else {
} else {
HttpResponse::Ok()
}
}))
.service(web::resource("/login").to(|id: Identity| {
id.remember(COOKIE_LOGIN.to_string());
HttpResponse::Ok() HttpResponse::Ok()
})) }
.service(web::resource("/logout").to(|id: Identity| { }))
if id.identity().is_some() { .service(web::resource("/login").to(|id: Identity| {
id.forget(); id.remember(COOKIE_LOGIN.to_string());
HttpResponse::Ok() HttpResponse::Ok()
} else { }))
HttpResponse::BadRequest() .service(web::resource("/logout").to(|id: Identity| {
} if id.identity().is_some() {
})), id.forget();
) HttpResponse::Ok()
.await; } else {
let resp = test::call_service( HttpResponse::BadRequest()
&mut srv, }
TestRequest::with_uri("/index").to_request(), })),
) );
.await; let resp =
assert_eq!(resp.status(), StatusCode::OK); test::call_service(&mut srv, TestRequest::with_uri("/index").to_request());
assert_eq!(resp.status(), StatusCode::OK);
let resp = test::call_service( let resp =
&mut srv, test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
TestRequest::with_uri("/login").to_request(), assert_eq!(resp.status(), StatusCode::OK);
) let c = resp.response().cookies().next().unwrap().to_owned();
.await;
assert_eq!(resp.status(), StatusCode::OK);
let c = resp.response().cookies().next().unwrap().to_owned();
let resp = test::call_service( let resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/index") TestRequest::with_uri("/index")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
) );
.await; assert_eq!(resp.status(), StatusCode::CREATED);
assert_eq!(resp.status(), StatusCode::CREATED);
let resp = test::call_service( let resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/logout") TestRequest::with_uri("/logout")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
) );
.await; assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.status(), StatusCode::OK); assert!(resp.headers().contains_key(header::SET_COOKIE))
assert!(resp.headers().contains_key(header::SET_COOKIE))
})
} }
#[test] #[test]
fn test_identity_max_age_time() { fn test_identity_max_age_time() {
block_on(async { let duration = Duration::days(1);
let duration = Duration::days(1); let mut srv = test::init_service(
let mut srv = test::init_service( App::new()
App::new() .wrap(IdentityService::new(
.wrap(IdentityService::new( CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .domain("www.rust-lang.org")
.domain("www.rust-lang.org") .name(COOKIE_NAME)
.name(COOKIE_NAME) .path("/")
.path("/") .max_age_time(duration)
.max_age_time(duration) .secure(true),
.secure(true), ))
)) .service(web::resource("/login").to(|id: Identity| {
.service(web::resource("/login").to(|id: Identity| { id.remember("test".to_string());
id.remember("test".to_string()); HttpResponse::Ok()
HttpResponse::Ok() })),
})), );
) let resp =
.await; test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
let resp = test::call_service( assert_eq!(resp.status(), StatusCode::OK);
&mut srv, assert!(resp.headers().contains_key(header::SET_COOKIE));
TestRequest::with_uri("/login").to_request(), let c = resp.response().cookies().next().unwrap().to_owned();
) assert_eq!(duration, c.max_age().unwrap());
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(duration, c.max_age().unwrap());
})
} }
#[test] #[test]
fn test_identity_max_age() { fn test_identity_max_age() {
block_on(async { let seconds = 60;
let seconds = 60; let mut srv = test::init_service(
let mut srv = test::init_service( App::new()
App::new() .wrap(IdentityService::new(
.wrap(IdentityService::new( CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .domain("www.rust-lang.org")
.domain("www.rust-lang.org") .name(COOKIE_NAME)
.name(COOKIE_NAME) .path("/")
.path("/") .max_age(seconds)
.max_age(seconds) .secure(true),
.secure(true), ))
)) .service(web::resource("/login").to(|id: Identity| {
.service(web::resource("/login").to(|id: Identity| { id.remember("test".to_string());
id.remember("test".to_string()); HttpResponse::Ok()
HttpResponse::Ok() })),
})), );
) let resp =
.await; test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
let resp = test::call_service( assert_eq!(resp.status(), StatusCode::OK);
&mut srv, assert!(resp.headers().contains_key(header::SET_COOKIE));
TestRequest::with_uri("/login").to_request(), let c = resp.response().cookies().next().unwrap().to_owned();
) assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap());
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap());
})
} }
async fn create_identity_server< fn create_identity_server<
F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static,
>( >(
f: F, f: F,
@@ -764,16 +747,13 @@ mod tests {
.secure(false) .secure(false)
.name(COOKIE_NAME)))) .name(COOKIE_NAME))))
.service(web::resource("/").to(|id: Identity| { .service(web::resource("/").to(|id: Identity| {
async move { let identity = id.identity();
let identity = id.identity(); if identity.is_none() {
if identity.is_none() { id.remember(COOKIE_LOGIN.to_string())
id.remember(COOKIE_LOGIN.to_string())
}
web::Json(identity)
} }
web::Json(identity)
})), })),
) )
.await
} }
fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> {
@@ -806,8 +786,15 @@ mod tests {
jar.get(COOKIE_NAME).unwrap().clone() jar.get(COOKIE_NAME).unwrap().clone()
} }
async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) {
let bytes = test::read_body(response).await; use bytes::BytesMut;
use futures::Stream;
let bytes =
test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| {
b.extend(c);
Ok::<_, Error>(b)
}))
.unwrap();
let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap(); let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap();
assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); assert_eq!(resp.as_ref().map(|s| s.borrow()), identity);
} }
@@ -887,221 +874,183 @@ mod tests {
#[test] #[test]
fn test_identity_legacy_cookie_is_set() { fn test_identity_legacy_cookie_is_set() {
block_on(async { let mut srv = create_identity_server(|c| c);
let mut srv = create_identity_server(|c| c).await; let mut resp =
let mut resp = test::call_service(&mut srv, TestRequest::with_uri("/").to_request());
test::call_service(&mut srv, TestRequest::with_uri("/").to_request()) assert_logged_in(&mut resp, None);
.await; assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_works() { fn test_identity_legacy_cookie_works() {
block_on(async { let mut srv = create_identity_server(|c| c);
let mut srv = create_identity_server(|c| c).await; let cookie = legacy_login_cookie(COOKIE_LOGIN);
let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut resp = test::call_service(
let mut resp = test::call_service( &mut srv,
&mut srv, TestRequest::with_uri("/")
TestRequest::with_uri("/") .cookie(cookie.clone())
.cookie(cookie.clone()) .to_request(),
.to_request(), );
) assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
.await; assert_no_login_cookie(&mut resp);
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() {
block_on(async { let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
let mut srv = let cookie = legacy_login_cookie(COOKIE_LOGIN);
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; let mut resp = test::call_service(
let cookie = legacy_login_cookie(COOKIE_LOGIN); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NoTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NewTimestamp,
LoginTimestampCheck::NoTimestamp, );
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() {
block_on(async { let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
let mut srv = let cookie = legacy_login_cookie(COOKIE_LOGIN);
create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let mut resp = test::call_service(
let cookie = legacy_login_cookie(COOKIE_LOGIN); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NewTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NoTimestamp,
LoginTimestampCheck::NewTimestamp, );
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_login_timestamp_needed() { fn test_identity_cookie_rejected_if_login_timestamp_needed() {
block_on(async { let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
let mut srv = let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now()));
create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let mut resp = test::call_service(
let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NewTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NoTimestamp,
LoginTimestampCheck::NewTimestamp, );
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_cookie_rejected_if_visit_timestamp_needed() {
block_on(async { let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
let mut srv = let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; let mut resp = test::call_service(
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NoTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NewTimestamp,
LoginTimestampCheck::NoTimestamp, );
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_login_timestamp_too_old() { fn test_identity_cookie_rejected_if_login_timestamp_too_old() {
block_on(async { let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
let mut srv = let cookie = login_cookie(
create_identity_server(|c| c.login_deadline(Duration::days(90))).await; COOKIE_LOGIN,
let cookie = login_cookie( Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
COOKIE_LOGIN, None,
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), );
None, let mut resp = test::call_service(
); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NewTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NoTimestamp,
LoginTimestampCheck::NewTimestamp, );
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { fn test_identity_cookie_rejected_if_visit_timestamp_too_old() {
block_on(async { let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
let mut srv = let cookie = login_cookie(
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; COOKIE_LOGIN,
let cookie = login_cookie( None,
COOKIE_LOGIN, Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
None, );
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), let mut resp = test::call_service(
); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, None);
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::NoTimestamp,
COOKIE_LOGIN, VisitTimeStampCheck::NewTimestamp,
LoginTimestampCheck::NoTimestamp, );
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_not_updated_on_login_deadline() { fn test_identity_cookie_not_updated_on_login_deadline() {
block_on(async { let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
let mut srv = let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
create_identity_server(|c| c.login_deadline(Duration::days(90))).await; let mut resp = test::call_service(
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
) assert_no_login_cookie(&mut resp);
.await;
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
#[test] #[test]
fn test_identity_cookie_updated_on_visit_deadline() { fn test_identity_cookie_updated_on_visit_deadline() {
block_on(async { let mut srv = create_identity_server(|c| {
let mut srv = create_identity_server(|c| { c.visit_deadline(Duration::days(90))
c.visit_deadline(Duration::days(90)) .login_deadline(Duration::days(90))
.login_deadline(Duration::days(90)) });
}) let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap();
.await; let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp));
let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); let mut resp = test::call_service(
let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); &mut srv,
let mut resp = test::call_service( TestRequest::with_uri("/")
&mut srv, .cookie(cookie.clone())
TestRequest::with_uri("/") .to_request(),
.cookie(cookie.clone()) );
.to_request(), assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
) assert_login_cookie(
.await; &mut resp,
assert_login_cookie( COOKIE_LOGIN,
&mut resp, LoginTimestampCheck::OldTimestamp(timestamp),
COOKIE_LOGIN, VisitTimeStampCheck::NewTimestamp,
LoginTimestampCheck::OldTimestamp(timestamp), );
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
} }

View File

@@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.5] - 2019-12-07
* Multipart handling now handles NotReady during read of boundary #1189
## [0.1.4] - 2019-09-12 ## [0.1.4] - 2019-09-12
* Multipart handling now parses requests which do not end in CRLF #1038 * Multipart handling now parses requests which do not end in CRLF #1038

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-multipart" name = "actix-multipart"
version = "0.2.0-alpha.1" version = "0.1.5"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Multipart support for actix web framework." description = "Multipart support for actix web framework."
readme = "README.md" readme = "README.md"
@@ -18,18 +18,17 @@ name = "actix_multipart"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "2.0.0-alpha.1", default-features = false } actix-web = { version = "1.0.0", default-features = false }
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
actix-utils = "0.5.0-alpha.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
httparse = "1.3" httparse = "1.3"
futures = "0.3.1" futures = "0.1.25"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
time = "0.1" time = "0.1"
twoway = "0.2" twoway = "0.2"
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"
actix-http = "0.3.0-alpha.1" actix-http = "0.2.4"

View File

@@ -1,6 +1,5 @@
//! Multipart payload support //! Multipart payload support
use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; use actix_web::{dev::Payload, Error, FromRequest, HttpRequest};
use futures::future::{ok, Ready};
use crate::server::Multipart; use crate::server::Multipart;
@@ -11,31 +10,33 @@ use crate::server::Multipart;
/// ## Server example /// ## Server example
/// ///
/// ```rust /// ```rust
/// use futures::{Stream, StreamExt}; /// # use futures::{Future, Stream};
/// # use futures::future::{ok, result, Either};
/// use actix_web::{web, HttpResponse, Error}; /// use actix_web::{web, HttpResponse, Error};
/// use actix_multipart as mp; /// use actix_multipart as mp;
/// ///
/// async fn index(mut payload: mp::Multipart) -> Result<HttpResponse, Error> { /// fn index(payload: mp::Multipart) -> impl Future<Item = HttpResponse, Error = Error> {
/// // iterate over multipart stream /// payload.from_err() // <- get multipart stream for current request
/// while let Some(item) = payload.next().await { /// .and_then(|field| { // <- iterate over multipart items
/// let mut field = item?;
///
/// // Field in turn is stream of *Bytes* object /// // Field in turn is stream of *Bytes* object
/// while let Some(chunk) = field.next().await { /// field.from_err()
/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?)); /// .fold((), |_, chunk| {
/// } /// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk));
/// } /// Ok::<_, Error>(())
/// Ok(HttpResponse::Ok().into()) /// })
/// })
/// .fold((), |_, _| Ok::<_, Error>(()))
/// .map(|_| HttpResponse::Ok().into())
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
impl FromRequest for Multipart { impl FromRequest for Multipart {
type Error = Error; type Error = Error;
type Future = Ready<Result<Multipart, Error>>; type Future = Result<Multipart, Error>;
type Config = (); type Config = ();
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
ok(Multipart::new(req.headers(), payload.take())) Ok(Multipart::new(req.headers(), payload.take()))
} }
} }

View File

@@ -1,17 +1,15 @@
//! Multipart payload support //! Multipart payload support
use std::cell::{Cell, RefCell, RefMut}; use std::cell::{Cell, RefCell, RefMut};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::{cmp, fmt}; use std::{cmp, fmt};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::stream::{LocalBoxStream, Stream, StreamExt}; use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream};
use httparse; use httparse;
use mime; use mime;
use actix_utils::task::LocalWaker;
use actix_web::error::{ParseError, PayloadError}; use actix_web::error::{ParseError, PayloadError};
use actix_web::http::header::{ use actix_web::http::header::{
self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, self, ContentDisposition, HeaderMap, HeaderName, HeaderValue,
@@ -62,7 +60,7 @@ impl Multipart {
/// Create multipart instance for boundary. /// Create multipart instance for boundary.
pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static, S: Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
match Self::boundary(headers) { match Self::boundary(headers) {
Ok(boundary) => Multipart { Ok(boundary) => Multipart {
@@ -106,25 +104,22 @@ impl Multipart {
} }
impl Stream for Multipart { impl Stream for Multipart {
type Item = Result<Field, MultipartError>; type Item = Field;
type Error = MultipartError;
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
Poll::Ready(Some(Err(err))) Err(err)
} else if self.safety.current() { } else if self.safety.current() {
let this = self.get_mut(); let mut inner = self.inner.as_mut().unwrap().borrow_mut();
let mut inner = this.inner.as_mut().unwrap().borrow_mut(); if let Some(mut payload) = inner.payload.get_mut(&self.safety) {
if let Some(mut payload) = inner.payload.get_mut(&this.safety) { payload.poll_stream()?;
payload.poll_stream(cx)?;
} }
inner.poll(&this.safety, cx) inner.poll(&self.safety)
} else if !self.safety.is_clean() { } else if !self.safety.is_clean() {
Poll::Ready(Some(Err(MultipartError::NotConsumed))) Err(MultipartError::NotConsumed)
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -243,13 +238,9 @@ impl InnerMultipart {
Ok(Some(eof)) Ok(Some(eof))
} }
fn poll( fn poll(&mut self, safety: &Safety) -> Poll<Option<Field>, MultipartError> {
&mut self,
safety: &Safety,
cx: &mut Context,
) -> Poll<Option<Result<Field, MultipartError>>> {
if self.state == InnerState::Eof { if self.state == InnerState::Eof {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
// release field // release field
loop { loop {
@@ -258,13 +249,10 @@ impl InnerMultipart {
if safety.current() { if safety.current() {
let stop = match self.item { let stop = match self.item {
InnerMultipartItem::Field(ref mut field) => { InnerMultipartItem::Field(ref mut field) => {
match field.borrow_mut().poll(safety) { match field.borrow_mut().poll(safety)? {
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
Poll::Ready(Some(Ok(_))) => continue, Async::Ready(Some(_)) => continue,
Poll::Ready(Some(Err(e))) => { Async::Ready(None) => true,
return Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => true,
} }
} }
InnerMultipartItem::None => false, InnerMultipartItem::None => false,
@@ -289,12 +277,12 @@ impl InnerMultipart {
Some(eof) => { Some(eof) => {
if eof { if eof {
self.state = InnerState::Eof; self.state = InnerState::Eof;
return Poll::Ready(None); return Ok(Async::Ready(None));
} else { } else {
self.state = InnerState::Headers; self.state = InnerState::Headers;
} }
} }
None => return Poll::Pending, None => return Ok(Async::NotReady),
} }
} }
// read boundary // read boundary
@@ -303,11 +291,11 @@ impl InnerMultipart {
&mut *payload, &mut *payload,
&self.boundary, &self.boundary,
)? { )? {
None => return Poll::Pending, None => return Ok(Async::NotReady),
Some(eof) => { Some(eof) => {
if eof { if eof {
self.state = InnerState::Eof; self.state = InnerState::Eof;
return Poll::Ready(None); return Ok(Async::Ready(None));
} else { } else {
self.state = InnerState::Headers; self.state = InnerState::Headers;
} }
@@ -323,14 +311,14 @@ impl InnerMultipart {
self.state = InnerState::Boundary; self.state = InnerState::Boundary;
headers headers
} else { } else {
return Poll::Pending; return Ok(Async::NotReady);
} }
} else { } else {
unreachable!() unreachable!()
} }
} else { } else {
log::debug!("NotReady: field is in flight"); log::debug!("NotReady: field is in flight");
return Poll::Pending; return Ok(Async::NotReady);
}; };
// content type // content type
@@ -347,7 +335,7 @@ impl InnerMultipart {
// nested multipart stream // nested multipart stream
if mt.type_() == mime::MULTIPART { if mt.type_() == mime::MULTIPART {
Poll::Ready(Some(Err(MultipartError::Nested))) Err(MultipartError::Nested)
} else { } else {
let field = Rc::new(RefCell::new(InnerField::new( let field = Rc::new(RefCell::new(InnerField::new(
self.payload.clone(), self.payload.clone(),
@@ -356,7 +344,12 @@ impl InnerMultipart {
)?)); )?));
self.item = InnerMultipartItem::Field(Rc::clone(&field)); self.item = InnerMultipartItem::Field(Rc::clone(&field));
Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) Ok(Async::Ready(Some(Field::new(
safety.clone(),
headers,
mt,
field,
))))
} }
} }
} }
@@ -416,21 +409,23 @@ impl Field {
} }
impl Stream for Field { impl Stream for Field {
type Item = Result<Bytes, MultipartError>; type Item = Bytes;
type Error = MultipartError;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.safety.current() { if self.safety.current() {
let mut inner = self.inner.borrow_mut(); let mut inner = self.inner.borrow_mut();
if let Some(mut payload) = if let Some(mut payload) =
inner.payload.as_ref().unwrap().get_mut(&self.safety) inner.payload.as_ref().unwrap().get_mut(&self.safety)
{ {
payload.poll_stream(cx)?; payload.poll_stream()?;
} }
inner.poll(&self.safety) inner.poll(&self.safety)
} else if !self.safety.is_clean() { } else if !self.safety.is_clean() {
Poll::Ready(Some(Err(MultipartError::NotConsumed))) Err(MultipartError::NotConsumed)
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -487,9 +482,9 @@ impl InnerField {
fn read_len( fn read_len(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
size: &mut u64, size: &mut u64,
) -> Poll<Option<Result<Bytes, MultipartError>>> { ) -> Poll<Option<Bytes>, MultipartError> {
if *size == 0 { if *size == 0 {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
match payload.read_max(*size)? { match payload.read_max(*size)? {
Some(mut chunk) => { Some(mut chunk) => {
@@ -499,13 +494,13 @@ impl InnerField {
if !chunk.is_empty() { if !chunk.is_empty() {
payload.unprocessed(chunk); payload.unprocessed(chunk);
} }
Poll::Ready(Some(Ok(ch))) Ok(Async::Ready(Some(ch)))
} }
None => { None => {
if payload.eof && (*size != 0) { if payload.eof && (*size != 0) {
Poll::Ready(Some(Err(MultipartError::Incomplete))) Err(MultipartError::Incomplete)
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -517,15 +512,15 @@ impl InnerField {
fn read_stream( fn read_stream(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
boundary: &str, boundary: &str,
) -> Poll<Option<Result<Bytes, MultipartError>>> { ) -> Poll<Option<Bytes>, MultipartError> {
let mut pos = 0; let mut pos = 0;
let len = payload.buf.len(); let len = payload.buf.len();
if len == 0 { if len == 0 {
return if payload.eof { return if payload.eof {
Poll::Ready(Some(Err(MultipartError::Incomplete))) Err(MultipartError::Incomplete)
} else { } else {
Poll::Pending Ok(Async::NotReady)
}; };
} }
@@ -542,10 +537,10 @@ impl InnerField {
if let Some(b_len) = b_len { if let Some(b_len) = b_len {
let b_size = boundary.len() + b_len; let b_size = boundary.len() + b_len;
if len < b_size { if len < b_size {
return Poll::Pending; return Ok(Async::NotReady);
} else if &payload.buf[b_len..b_size] == boundary.as_bytes() { } else if &payload.buf[b_len..b_size] == boundary.as_bytes() {
// found boundary // found boundary
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
} }
} }
@@ -557,9 +552,9 @@ impl InnerField {
// check if we have enough data for boundary detection // check if we have enough data for boundary detection
if cur + 4 > len { if cur + 4 > len {
if cur > 0 { if cur > 0 {
Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze())))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} else { } else {
// check boundary // check boundary
@@ -570,7 +565,7 @@ impl InnerField {
{ {
if cur != 0 { if cur != 0 {
// return buffer // return buffer
Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze())))
} else { } else {
pos = cur + 1; pos = cur + 1;
continue; continue;
@@ -582,51 +577,49 @@ impl InnerField {
} }
} }
} else { } else {
Poll::Ready(Some(Ok(payload.buf.take().freeze()))) Ok(Async::Ready(Some(payload.buf.take().freeze())))
}; };
} }
} }
fn poll(&mut self, s: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> { fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> {
if self.payload.is_none() { if self.payload.is_none() {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s)
{ {
if !self.eof { if !self.eof {
let res = if let Some(ref mut len) = self.length { let res = if let Some(ref mut len) = self.length {
InnerField::read_len(&mut *payload, len) InnerField::read_len(&mut *payload, len)?
} else { } else {
InnerField::read_stream(&mut *payload, &self.boundary) InnerField::read_stream(&mut *payload, &self.boundary)?
}; };
match res { match res {
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), Async::Ready(Some(bytes)) => return Ok(Async::Ready(Some(bytes))),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Async::Ready(None) => self.eof = true,
Poll::Ready(None) => self.eof = true,
} }
} }
match payload.readline() { match payload.readline()? {
Ok(None) => Poll::Ready(None), None => Async::NotReady,
Ok(Some(line)) => { Some(line) => {
if line.as_ref() != b"\r\n" { if line.as_ref() != b"\r\n" {
log::warn!("multipart field did not read all the data or it is malformed"); log::warn!("multipart field did not read all the data or it is malformed");
} }
Poll::Ready(None) Async::Ready(None)
} }
Err(e) => Poll::Ready(Some(Err(e))),
} }
} else { } else {
Poll::Pending Async::NotReady
}; };
if let Poll::Ready(None) = result { if Async::Ready(None) == result {
self.payload.take(); self.payload.take();
} }
result Ok(result)
} }
} }
@@ -666,7 +659,7 @@ impl Clone for PayloadRef {
/// most task. /// most task.
#[derive(Debug)] #[derive(Debug)]
struct Safety { struct Safety {
task: LocalWaker, task: Option<Task>,
level: usize, level: usize,
payload: Rc<PhantomData<bool>>, payload: Rc<PhantomData<bool>>,
clean: Rc<Cell<bool>>, clean: Rc<Cell<bool>>,
@@ -676,7 +669,7 @@ impl Safety {
fn new() -> Safety { fn new() -> Safety {
let payload = Rc::new(PhantomData); let payload = Rc::new(PhantomData);
Safety { Safety {
task: LocalWaker::new(), task: None,
level: Rc::strong_count(&payload), level: Rc::strong_count(&payload),
clean: Rc::new(Cell::new(true)), clean: Rc::new(Cell::new(true)),
payload, payload,
@@ -690,17 +683,17 @@ impl Safety {
fn is_clean(&self) -> bool { fn is_clean(&self) -> bool {
self.clean.get() self.clean.get()
} }
}
fn clone(&self, cx: &mut Context) -> Safety { impl Clone for Safety {
fn clone(&self) -> Safety {
let payload = Rc::clone(&self.payload); let payload = Rc::clone(&self.payload);
let s = Safety { Safety {
task: LocalWaker::new(), task: Some(current_task()),
level: Rc::strong_count(&payload), level: Rc::strong_count(&payload),
clean: self.clean.clone(), clean: self.clean.clone(),
payload, payload,
}; }
s.task.register(cx.waker());
s
} }
} }
@@ -711,7 +704,7 @@ impl Drop for Safety {
self.clean.set(true); self.clean.set(true);
} }
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
task.wake() task.notify()
} }
} }
} }
@@ -720,32 +713,31 @@ impl Drop for Safety {
struct PayloadBuffer { struct PayloadBuffer {
eof: bool, eof: bool,
buf: BytesMut, buf: BytesMut,
stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>, stream: Box<dyn Stream<Item = Bytes, Error = PayloadError>>,
} }
impl PayloadBuffer { impl PayloadBuffer {
/// Create new `PayloadBuffer` instance /// Create new `PayloadBuffer` instance
fn new<S>(stream: S) -> Self fn new<S>(stream: S) -> Self
where where
S: Stream<Item = Result<Bytes, PayloadError>> + 'static, S: Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
PayloadBuffer { PayloadBuffer {
eof: false, eof: false,
buf: BytesMut::new(), buf: BytesMut::new(),
stream: stream.boxed_local(), stream: Box::new(stream),
} }
} }
fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> { fn poll_stream(&mut self) -> Result<(), PayloadError> {
loop { loop {
match Pin::new(&mut self.stream).poll_next(cx) { match self.stream.poll()? {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), Async::Ready(Some(data)) => self.buf.extend_from_slice(&data),
Poll::Ready(Some(Err(e))) => return Err(e), Async::Ready(None) => {
Poll::Ready(None) => {
self.eof = true; self.eof = true;
return Ok(()); return Ok(());
} }
Poll::Pending => return Ok(()), Async::NotReady => return Ok(()),
} }
} }
} }
@@ -808,14 +800,13 @@ impl PayloadBuffer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use actix_http::h1::Payload; use actix_http::h1::Payload;
use actix_utils::mpsc;
use actix_web::http::header::{DispositionParam, DispositionType};
use actix_web::test::block_on;
use bytes::Bytes; use bytes::Bytes;
use futures::future::lazy; use futures::unsync::mpsc;
use super::*;
use actix_web::http::header::{DispositionParam, DispositionType};
use actix_web::test::run_on;
#[test] #[test]
fn test_boundary() { fn test_boundary() {
@@ -861,12 +852,48 @@ mod tests {
} }
fn create_stream() -> ( fn create_stream() -> (
mpsc::Sender<Result<Bytes, PayloadError>>, mpsc::UnboundedSender<Result<Bytes, PayloadError>>,
impl Stream<Item = Result<Bytes, PayloadError>>, impl Stream<Item = Bytes, Error = PayloadError>,
) { ) {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::unbounded();
(tx, rx.map(|res| res.map_err(|_| panic!()))) (tx, rx.map_err(|_| panic!()).and_then(|res| res))
}
// Stream that returns from a Bytes, one char at a time and NotReady every other poll()
struct SlowStream {
bytes: Bytes,
pos: usize,
ready: bool,
}
impl SlowStream {
fn new(bytes: Bytes) -> SlowStream {
return SlowStream {
bytes: bytes,
pos: 0,
ready: false,
}
}
}
impl Stream for SlowStream {
type Item = Bytes;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if !self.ready {
self.ready = true;
return Ok(Async::NotReady);
}
if self.pos == self.bytes.len() {
return Ok(Async::Ready(None));
}
let res = Ok(Async::Ready(Some(self.bytes.slice(self.pos, self.pos + 1))));
self.pos += 1;
self.ready = false;
res
}
} }
fn create_simple_request_with_header() -> (Bytes, HeaderMap) { fn create_simple_request_with_header() -> (Bytes, HeaderMap) {
@@ -893,28 +920,28 @@ mod tests {
#[test] #[test]
fn test_multipart_no_end_crlf() { fn test_multipart_no_end_crlf() {
block_on(async { run_on(|| {
let (sender, payload) = create_stream(); let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf
sender.send(Ok(bytes_stripped)).unwrap(); sender.unbounded_send(Ok(bytes_stripped)).unwrap();
drop(sender); // eof drop(sender); // eof
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.next().await.unwrap() { match multipart.poll().unwrap() {
Ok(_) => (), Async::Ready(Some(_)) => (),
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await.unwrap() { match multipart.poll().unwrap() {
Ok(_) => (), Async::Ready(Some(_)) => (),
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await { match multipart.poll().unwrap() {
None => (), Async::Ready(None) => (),
_ => unreachable!(), _ => unreachable!(),
} }
}) })
@@ -922,15 +949,15 @@ mod tests {
#[test] #[test]
fn test_multipart() { fn test_multipart() {
block_on(async { run_on(|| {
let (sender, payload) = create_stream(); let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
sender.send(Ok(bytes)).unwrap(); sender.unbounded_send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.next().await { match multipart.poll().unwrap() {
Some(Ok(mut field)) => { Async::Ready(Some(mut field)) => {
let cd = field.content_disposition().unwrap(); let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -938,53 +965,75 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.next().await.unwrap() { match field.poll().unwrap() {
Ok(chunk) => assert_eq!(chunk, "test"), Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.next().await { match field.poll().unwrap() {
None => (), Async::Ready(None) => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await.unwrap() { match multipart.poll().unwrap() {
Ok(mut field) => { Async::Ready(Some(mut field)) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.next().await { match field.poll() {
Some(Ok(chunk)) => assert_eq!(chunk, "data"), Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.next().await { match field.poll() {
None => (), Ok(Async::Ready(None)) => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await { match multipart.poll().unwrap() {
None => (), Async::Ready(None) => (),
_ => unreachable!(), _ => unreachable!(),
} }
}); });
} }
// Retries on NotReady
fn loop_poll<T>(stream: &mut T) -> Poll<Option<T::Item>, T::Error>
where T: Stream {
loop {
let r = stream.poll();
match r {
Ok(Async::NotReady) => continue,
_ => return r,
}
}
}
// Loops polling, collecting all bytes until end-of-field
fn get_whole_field(field: &mut Field) -> BytesMut {
let mut b = BytesMut::new();
loop {
match loop_poll(field) {
Ok(Async::Ready(Some(chunk))) => b.extend_from_slice(&chunk),
Ok(Async::Ready(None)) => return b,
_ => unreachable!(),
}
}
}
#[test] #[test]
fn test_stream() { fn test_stream() {
block_on(async { run_on(|| {
let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
let payload = SlowStream::new(bytes);
sender.send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.next().await.unwrap() { match loop_poll(&mut multipart).unwrap() {
Ok(mut field) => { Async::Ready(Some(mut field)) => {
let cd = field.content_disposition().unwrap(); let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -992,64 +1041,45 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.next().await.unwrap() { assert_eq!(get_whole_field(&mut field), "test");
Ok(chunk) => assert_eq!(chunk, "test"),
_ => unreachable!(),
}
match field.next().await {
None => (),
_ => unreachable!(),
}
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await { match loop_poll(&mut multipart).unwrap() {
Some(Ok(mut field)) => { Async::Ready(Some(mut field)) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.next().await { assert_eq!(get_whole_field(&mut field), "data");
Some(Ok(chunk)) => assert_eq!(chunk, "data"),
_ => unreachable!(),
}
match field.next().await {
None => (),
_ => unreachable!(),
}
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.next().await {
None => (),
_ => unreachable!(),
}
}); });
} }
#[test] #[test]
fn test_basic() { fn test_basic() {
block_on(async { run_on(|| {
let (_, payload) = Payload::create(false); let (_, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(payload.buf.len(), 0); assert_eq!(payload.buf.len(), 0);
lazy(|cx| payload.poll_stream(cx)).await.unwrap(); payload.poll_stream().unwrap();
assert_eq!(None, payload.read_max(1).unwrap()); assert_eq!(None, payload.read_max(1).unwrap());
}) })
} }
#[test] #[test]
fn test_eof() { fn test_eof() {
block_on(async { run_on(|| {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(4).unwrap()); assert_eq!(None, payload.read_max(4).unwrap());
sender.feed_data(Bytes::from("data")); sender.feed_data(Bytes::from("data"));
sender.feed_eof(); sender.feed_eof();
lazy(|cx| payload.poll_stream(cx)).await.unwrap(); payload.poll_stream().unwrap();
assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap());
assert_eq!(payload.buf.len(), 0); assert_eq!(payload.buf.len(), 0);
@@ -1060,24 +1090,24 @@ mod tests {
#[test] #[test]
fn test_err() { fn test_err() {
block_on(async { run_on(|| {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(1).unwrap()); assert_eq!(None, payload.read_max(1).unwrap());
sender.set_error(PayloadError::Incomplete(None)); sender.set_error(PayloadError::Incomplete(None));
lazy(|cx| payload.poll_stream(cx)).await.err().unwrap(); payload.poll_stream().err().unwrap();
}) })
} }
#[test] #[test]
fn test_readmax() { fn test_readmax() {
block_on(async { run_on(|| {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
lazy(|cx| payload.poll_stream(cx)).await.unwrap(); payload.poll_stream().unwrap();
assert_eq!(payload.buf.len(), 10); assert_eq!(payload.buf.len(), 10);
assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap());
@@ -1090,7 +1120,7 @@ mod tests {
#[test] #[test]
fn test_readexactly() { fn test_readexactly() {
block_on(async { run_on(|| {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
@@ -1098,7 +1128,7 @@ mod tests {
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
lazy(|cx| payload.poll_stream(cx)).await.unwrap(); payload.poll_stream().unwrap();
assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2));
assert_eq!(payload.buf.len(), 8); assert_eq!(payload.buf.len(), 8);
@@ -1110,7 +1140,7 @@ mod tests {
#[test] #[test]
fn test_readuntil() { fn test_readuntil() {
block_on(async { run_on(|| {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
@@ -1118,7 +1148,7 @@ mod tests {
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
lazy(|cx| payload.poll_stream(cx)).await.unwrap(); payload.poll_stream().unwrap();
assert_eq!( assert_eq!(
Some(Bytes::from("line")), Some(Bytes::from("line")),

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-session" name = "actix-session"
version = "0.3.0-alpha.1" version = "0.2.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Session for actix web framework." description = "Session for actix web framework."
readme = "README.md" readme = "README.md"
@@ -24,15 +24,15 @@ default = ["cookie-session"]
cookie-session = ["actix-web/secure-cookies"] cookie-session = ["actix-web/secure-cookies"]
[dependencies] [dependencies]
actix-web = "2.0.0-alpha.1" actix-web = "1.0.0"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.3.1" futures = "0.1.25"
hashbrown = "0.6.3" hashbrown = "0.5.0"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.1.42" time = "0.1.42"
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"

View File

@@ -17,7 +17,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
@@ -25,7 +24,8 @@ use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::http::{header::SET_COOKIE, HeaderValue};
use actix_web::{Error, HttpMessage, ResponseError}; use actix_web::{Error, HttpMessage, ResponseError};
use derive_more::{Display, From}; use derive_more::{Display, From};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Future, FutureResult};
use futures::Poll;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use crate::{Session, SessionStatus}; use crate::{Session, SessionStatus};
@@ -284,7 +284,7 @@ where
type Error = S::Error; type Error = S::Error;
type InitError = (); type InitError = ();
type Transform = CookieSessionMiddleware<S>; type Transform = CookieSessionMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CookieSessionMiddleware { ok(CookieSessionMiddleware {
@@ -309,10 +309,10 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
/// On first request, a new session cookie is returned in response, regardless /// On first request, a new session cookie is returned in response, regardless
@@ -325,36 +325,29 @@ where
let (is_new, state) = self.inner.load(&req); let (is_new, state) = self.inner.load(&req);
Session::set_session(state.into_iter(), &mut req); Session::set_session(state.into_iter(), &mut req);
let fut = self.service.call(req); Box::new(self.service.call(req).map(move |mut res| {
match Session::get_changes(&mut res) {
async move { (SessionStatus::Changed, Some(state))
fut.await.map(|mut res| { | (SessionStatus::Renewed, Some(state)) => {
match Session::get_changes(&mut res) { res.checked_expr(|res| inner.set_cookie(res, state))
(SessionStatus::Changed, Some(state)) }
| (SessionStatus::Renewed, Some(state)) => { (SessionStatus::Unchanged, _) =>
res.checked_expr(|res| inner.set_cookie(res, state)) // set a new session cookie upon first request (new client)
} {
(SessionStatus::Unchanged, _) => if is_new {
// set a new session cookie upon first request (new client) let state: HashMap<String, String> = HashMap::new();
{ res.checked_expr(|res| inner.set_cookie(res, state.into_iter()))
if is_new { } else {
let state: HashMap<String, String> = HashMap::new();
res.checked_expr(|res| {
inner.set_cookie(res, state.into_iter())
})
} else {
res
}
}
(SessionStatus::Purged, _) => {
let _ = inner.remove_cookie(&mut res);
res res
} }
_ => res,
} }
}) (SessionStatus::Purged, _) => {
} let _ = inner.remove_cookie(&mut res);
.boxed_local() res
}
_ => res,
}
}))
} }
} }
@@ -366,123 +359,101 @@ mod tests {
#[test] #[test]
fn cookie_session() { fn cookie_session() {
test::block_on(async { let mut app = test::init_service(
let mut app = test::init_service( App::new()
App::new() .wrap(CookieSession::signed(&[0; 32]).secure(false))
.wrap(CookieSession::signed(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| {
.service(web::resource("/").to(|ses: Session| { let _ = ses.set("counter", 100);
async move { "test"
let _ = ses.set("counter", 100); })),
"test" );
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = app.call(request).await.unwrap(); let response = test::block_on(app.call(request)).unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn private_cookie() { fn private_cookie() {
test::block_on(async { let mut app = test::init_service(
let mut app = test::init_service( App::new()
App::new() .wrap(CookieSession::private(&[0; 32]).secure(false))
.wrap(CookieSession::private(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| {
.service(web::resource("/").to(|ses: Session| { let _ = ses.set("counter", 100);
async move { "test"
let _ = ses.set("counter", 100); })),
"test" );
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = app.call(request).await.unwrap(); let response = test::block_on(app.call(request)).unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn cookie_session_extractor() { fn cookie_session_extractor() {
test::block_on(async { let mut app = test::init_service(
let mut app = test::init_service( App::new()
App::new() .wrap(CookieSession::signed(&[0; 32]).secure(false))
.wrap(CookieSession::signed(&[0; 32]).secure(false)) .service(web::resource("/").to(|ses: Session| {
.service(web::resource("/").to(|ses: Session| { let _ = ses.set("counter", 100);
async move { "test"
let _ = ses.set("counter", 100); })),
"test" );
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = app.call(request).await.unwrap(); let response = test::block_on(app.call(request)).unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn basics() { fn basics() {
test::block_on(async { let mut app = test::init_service(
let mut app = test::init_service( App::new()
App::new() .wrap(
.wrap( CookieSession::signed(&[0; 32])
CookieSession::signed(&[0; 32]) .path("/test/")
.path("/test/") .name("actix-test")
.name("actix-test") .domain("localhost")
.domain("localhost") .http_only(true)
.http_only(true) .same_site(SameSite::Lax)
.same_site(SameSite::Lax) .max_age(100),
.max_age(100), )
) .service(web::resource("/").to(|ses: Session| {
.service(web::resource("/").to(|ses: Session| { let _ = ses.set("counter", 100);
async move { "test"
let _ = ses.set("counter", 100); }))
"test" .service(web::resource("/test/").to(|ses: Session| {
} let val: usize = ses.get("counter").unwrap().unwrap();
})) format!("counter: {}", val)
.service(web::resource("/test/").to(|ses: Session| { })),
async move { );
let val: usize = ses.get("counter").unwrap().unwrap();
format!("counter: {}", val)
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = app.call(request).await.unwrap(); let response = test::block_on(app.call(request)).unwrap();
let cookie = response let cookie = response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-test") .find(|c| c.name() == "actix-test")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(cookie.path().unwrap(), "/test/"); assert_eq!(cookie.path().unwrap(), "/test/");
let request = test::TestRequest::with_uri("/test/") let request = test::TestRequest::with_uri("/test/")
.cookie(cookie) .cookie(cookie)
.to_request(); .to_request();
let body = test::read_response(&mut app, request).await; let body = test::read_response(&mut app, request);
assert_eq!(body, Bytes::from_static(b"counter: 100")); assert_eq!(body, Bytes::from_static(b"counter: 100"));
})
} }
} }

View File

@@ -47,7 +47,6 @@ use std::rc::Rc;
use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse};
use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; use actix_web::{Error, FromRequest, HttpMessage, HttpRequest};
use futures::future::{ok, Ready};
use hashbrown::HashMap; use hashbrown::HashMap;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@@ -231,12 +230,12 @@ impl Session {
/// ``` /// ```
impl FromRequest for Session { impl FromRequest for Session {
type Error = Error; type Error = Error;
type Future = Ready<Result<Session, Error>>; type Future = Result<Session, Error>;
type Config = (); type Config = ();
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(Session::get_session(&mut *req.extensions_mut())) Ok(Session::get_session(&mut *req.extensions_mut()))
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web-codegen" name = "actix-web-codegen"
version = "0.2.0-alpha.1" version = "0.1.3"
description = "Actix web proc macros" description = "Actix web proc macros"
readme = "README.md" readme = "README.md"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
@@ -17,7 +17,7 @@ syn = { version = "1", features = ["full", "parsing"] }
proc-macro2 = "1" proc-macro2 = "1"
[dev-dependencies] [dev-dependencies]
actix-web = { version = "2.0.0-alph.a" } actix-web = { version = "1.0.0" }
actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } actix-http = { version = "0.2.4", features=["ssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } actix-http-test = { version = "0.2.0", features=["ssl"] }
futures = { version = "0.3.1" } futures = { version = "0.1" }

View File

@@ -35,8 +35,8 @@
//! use futures::{future, Future}; //! use futures::{future, Future};
//! //!
//! #[get("/test")] //! #[get("/test")]
//! async fn async_test() -> Result<HttpResponse, actix_web::Error> { //! fn async_test() -> impl Future<Item=HttpResponse, Error=actix_web::Error> {
//! Ok(HttpResponse::Ok().finish()) //! future::ok(HttpResponse::Ok().finish())
//! } //! }
//! ``` //! ```

View File

@@ -13,7 +13,7 @@ enum ResourceType {
impl ToTokens for ResourceType { impl ToTokens for ResourceType {
fn to_tokens(&self, stream: &mut TokenStream2) { fn to_tokens(&self, stream: &mut TokenStream2) {
let ident = match self { let ident = match self {
ResourceType::Async => "to", ResourceType::Async => "to_async",
ResourceType::Sync => "to", ResourceType::Sync => "to",
}; };
let ident = Ident::new(ident, Span::call_site()); let ident = Ident::new(ident, Span::call_site());

View File

@@ -1,163 +1,157 @@
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use actix_web::{http, web::Path, App, HttpResponse, Responder}; use actix_web::{http, web::Path, App, HttpResponse, Responder};
use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace};
use futures::{future, Future}; use futures::{future, Future};
#[get("/test")] #[get("/test")]
async fn test() -> impl Responder { fn test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[put("/test")] #[put("/test")]
async fn put_test() -> impl Responder { fn put_test() -> impl Responder {
HttpResponse::Created() HttpResponse::Created()
} }
#[patch("/test")] #[patch("/test")]
async fn patch_test() -> impl Responder { fn patch_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[post("/test")] #[post("/test")]
async fn post_test() -> impl Responder { fn post_test() -> impl Responder {
HttpResponse::NoContent() HttpResponse::NoContent()
} }
#[head("/test")] #[head("/test")]
async fn head_test() -> impl Responder { fn head_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[connect("/test")] #[connect("/test")]
async fn connect_test() -> impl Responder { fn connect_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[options("/test")] #[options("/test")]
async fn options_test() -> impl Responder { fn options_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[trace("/test")] #[trace("/test")]
async fn trace_test() -> impl Responder { fn trace_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[get("/test")] #[get("/test")]
fn auto_async() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> { fn auto_async() -> impl Future<Item = HttpResponse, Error = actix_web::Error> {
future::ok(HttpResponse::Ok().finish()) future::ok(HttpResponse::Ok().finish())
} }
#[get("/test")] #[get("/test")]
fn auto_sync() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> { fn auto_sync() -> impl Future<Item = HttpResponse, Error = actix_web::Error> {
future::ok(HttpResponse::Ok().finish()) future::ok(HttpResponse::Ok().finish())
} }
#[put("/test/{param}")] #[put("/test/{param}")]
async fn put_param_test(_: Path<String>) -> impl Responder { fn put_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Created() HttpResponse::Created()
} }
#[delete("/test/{param}")] #[delete("/test/{param}")]
async fn delete_param_test(_: Path<String>) -> impl Responder { fn delete_param_test(_: Path<String>) -> impl Responder {
HttpResponse::NoContent() HttpResponse::NoContent()
} }
#[get("/test/{param}")] #[get("/test/{param}")]
async fn get_param_test(_: Path<String>) -> impl Responder { fn get_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[test] #[test]
fn test_params() { fn test_params() {
block_on(async { let mut srv = TestServer::new(|| {
let srv = TestServer::start(|| { HttpService::new(
HttpService::new( App::new()
App::new() .service(get_param_test)
.service(get_param_test) .service(put_param_test)
.service(put_param_test) .service(delete_param_test),
.service(delete_param_test), )
) });
});
let request = srv.request(http::Method::GET, srv.url("/test/it")); let request = srv.request(http::Method::GET, srv.url("/test/it"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert_eq!(response.status(), http::StatusCode::OK); assert_eq!(response.status(), http::StatusCode::OK);
let request = srv.request(http::Method::PUT, srv.url("/test/it")); let request = srv.request(http::Method::PUT, srv.url("/test/it"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert_eq!(response.status(), http::StatusCode::CREATED); assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::DELETE, srv.url("/test/it")); let request = srv.request(http::Method::DELETE, srv.url("/test/it"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert_eq!(response.status(), http::StatusCode::NO_CONTENT); assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
})
} }
#[test] #[test]
fn test_body() { fn test_body() {
block_on(async { let mut srv = TestServer::new(|| {
let srv = TestServer::start(|| { HttpService::new(
HttpService::new( App::new()
App::new() .service(post_test)
.service(post_test) .service(put_test)
.service(put_test) .service(head_test)
.service(head_test) .service(connect_test)
.service(connect_test) .service(options_test)
.service(options_test) .service(trace_test)
.service(trace_test) .service(patch_test)
.service(patch_test) .service(test),
.service(test), )
) });
}); let request = srv.request(http::Method::GET, srv.url("/test"));
let request = srv.request(http::Method::GET, srv.url("/test")); let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap(); assert!(response.status().is_success());
assert!(response.status().is_success());
let request = srv.request(http::Method::HEAD, srv.url("/test")); let request = srv.request(http::Method::HEAD, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::CONNECT, srv.url("/test")); let request = srv.request(http::Method::CONNECT, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::OPTIONS, srv.url("/test")); let request = srv.request(http::Method::OPTIONS, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::TRACE, srv.url("/test")); let request = srv.request(http::Method::TRACE, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::PATCH, srv.url("/test")); let request = srv.request(http::Method::PATCH, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::PUT, srv.url("/test")); let request = srv.request(http::Method::PUT, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::CREATED); assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::POST, srv.url("/test")); let request = srv.request(http::Method::POST, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::NO_CONTENT); assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
let request = srv.request(http::Method::GET, srv.url("/test")); let request = srv.request(http::Method::GET, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }
#[test] #[test]
fn test_auto_async() { fn test_auto_async() {
block_on(async { let mut srv = TestServer::new(|| HttpService::new(App::new().service(auto_async)));
let srv = TestServer::start(|| HttpService::new(App::new().service(auto_async)));
let request = srv.request(http::Method::GET, srv.url("/test")); let request = srv.request(http::Method::GET, srv.url("/test"));
let response = request.send().await.unwrap(); let response = srv.block_on(request.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "awc" name = "awc"
version = "0.3.0-alpha.1" version = "0.2.8"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http client." description = "Actix http client."
readme = "README.md" readme = "README.md"
@@ -21,16 +21,16 @@ name = "awc"
path = "src/lib.rs" path = "src/lib.rs"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "brotli", "flate2-zlib"] features = ["ssl", "brotli", "flate2-zlib"]
[features] [features]
default = ["brotli", "flate2-zlib"] default = ["brotli", "flate2-zlib"]
# openssl # openssl
openssl = ["open-ssl", "actix-http/openssl"] ssl = ["openssl", "actix-http/ssl"]
# rustls # rustls
# rustls = ["rust-tls", "actix-http/rustls"] rust-tls = ["rustls", "actix-http/rust-tls"]
# brotli encoding, requires c compiler # brotli encoding, requires c compiler
brotli = ["actix-http/brotli"] brotli = ["actix-http/brotli"]
@@ -42,14 +42,13 @@ flate2-zlib = ["actix-http/flate2-zlib"]
flate2-rust = ["actix-http/flate2-rust"] flate2-rust = ["actix-http/flate2-rust"]
[dependencies] [dependencies]
actix-codec = "0.2.0-alpha.1" actix-codec = "0.1.2"
actix-service = "1.0.0-alpha.1" actix-service = "0.4.1"
actix-http = "0.3.0-alpha.1" actix-http = "0.2.11"
base64 = "0.10.1" base64 = "0.10.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.3.1" futures = "0.1.25"
log =" 0.4" log =" 0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
@@ -57,21 +56,21 @@ rand = "0.7"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
tokio-timer = "0.3.0-alpha.6" tokio-timer = "0.2.8"
open-ssl = { version="0.10", package="openssl", optional = true } openssl = { version="0.10", optional = true }
rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] } rustls = { version = "0.15.2", optional = true }
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0-alpha.1" actix-rt = "0.2.2"
actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } actix-web = { version = "1.0.8", features=["ssl"] }
actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } actix-http = { version = "0.2.11", features=["ssl"] }
actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } actix-http-test = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } actix-utils = "0.4.1"
actix-utils = "0.5.0-alpha.1" actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] }
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
brotli2 = { version="0.3.2" } brotli2 = { version="0.3.2" }
flate2 = { version="1.0.2" } flate2 = { version="1.0.2" }
env_logger = "0.6" env_logger = "0.6"
rand = "0.7" rand = "0.7"
tokio-tcp = "0.1" tokio-tcp = "0.1"
webpki = { version = "0.21" } webpki = "0.19"
rustls = { version = "0.15.2", features = ["dangerous_configuration"] }

View File

@@ -1,6 +1,4 @@
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, io, net}; use std::{fmt, io, net};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
@@ -12,7 +10,7 @@ use actix_http::h1::ClientCodec;
use actix_http::http::HeaderMap; use actix_http::http::HeaderMap;
use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_http::{RequestHead, RequestHeadType, ResponseHead};
use actix_service::Service; use actix_service::Service;
use futures::future::{FutureExt, LocalBoxFuture}; use futures::{Future, Poll};
use crate::response::ClientResponse; use crate::response::ClientResponse;
@@ -24,7 +22,7 @@ pub(crate) trait Connect {
head: RequestHead, head: RequestHead,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>; ) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>;
fn send_request_extra( fn send_request_extra(
&mut self, &mut self,
@@ -32,16 +30,18 @@ pub(crate) trait Connect {
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>; ) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel( fn open_tunnel(
&mut self, &mut self,
head: RequestHead, head: RequestHead,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture< ) -> Box<
'static, dyn Future<
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
>; >;
/// Send request and extra headers, returns Response and Framed /// Send request and extra headers, returns Response and Framed
@@ -50,9 +50,11 @@ pub(crate) trait Connect {
head: Rc<RequestHead>, head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture< ) -> Box<
'static, dyn Future<
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
>; >;
} }
@@ -70,23 +72,21 @@ where
head: RequestHead, head: RequestHead,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> { ) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> {
// connect to the host Box::new(
let fut = self.0.call(ClientConnect { self.0
uri: head.uri.clone(), // connect to the host
addr, .call(ClientConnect {
}); uri: head.uri.clone(),
addr,
async move { })
let connection = fut.await?; .from_err()
// send request
// send request .and_then(move |connection| {
connection connection.send_request(RequestHeadType::from(head), body)
.send_request(RequestHeadType::from(head), body) })
.await .map(|(head, payload)| ClientResponse::new(head, payload)),
.map(|(head, payload)| ClientResponse::new(head, payload)) )
}
.boxed_local()
} }
fn send_request_extra( fn send_request_extra(
@@ -95,51 +95,51 @@ where
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> { ) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> {
// connect to the host Box::new(
let fut = self.0.call(ClientConnect { self.0
uri: head.uri.clone(), // connect to the host
addr, .call(ClientConnect {
}); uri: head.uri.clone(),
addr,
async move { })
let connection = fut.await?; .from_err()
// send request
// send request .and_then(move |connection| {
let (head, payload) = connection connection
.send_request(RequestHeadType::Rc(head, extra_headers), body) .send_request(RequestHeadType::Rc(head, extra_headers), body)
.await?; })
.map(|(head, payload)| ClientResponse::new(head, payload)),
Ok(ClientResponse::new(head, payload)) )
}
.boxed_local()
} }
fn open_tunnel( fn open_tunnel(
&mut self, &mut self,
head: RequestHead, head: RequestHead,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture< ) -> Box<
'static, dyn Future<
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
> { > {
// connect to the host Box::new(
let fut = self.0.call(ClientConnect { self.0
uri: head.uri.clone(), // connect to the host
addr, .call(ClientConnect {
}); uri: head.uri.clone(),
addr,
async move { })
let connection = fut.await?; .from_err()
// send request
// send request .and_then(move |connection| {
let (head, framed) = connection.open_tunnel(RequestHeadType::from(head))
connection.open_tunnel(RequestHeadType::from(head)).await?; })
.map(|(head, framed)| {
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
Ok((head, framed)) (head, framed)
} }),
.boxed_local() )
} }
fn open_tunnel_extra( fn open_tunnel_extra(
@@ -147,47 +147,48 @@ where
head: Rc<RequestHead>, head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> LocalBoxFuture< ) -> Box<
'static, dyn Future<
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
> { > {
// connect to the host Box::new(
let fut = self.0.call(ClientConnect { self.0
uri: head.uri.clone(), // connect to the host
addr, .call(ClientConnect {
}); uri: head.uri.clone(),
addr,
async move { })
let connection = fut.await?; .from_err()
// send request
// send request .and_then(move |connection| {
let (head, framed) = connection connection.open_tunnel(RequestHeadType::Rc(head, extra_headers))
.open_tunnel(RequestHeadType::Rc(head, extra_headers)) })
.await?; .map(|(head, framed)| {
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); (head, framed)
Ok((head, framed)) }),
} )
.boxed_local()
} }
} }
trait AsyncSocket { trait AsyncSocket {
fn as_read(&self) -> &(dyn AsyncRead + Unpin); fn as_read(&self) -> &dyn AsyncRead;
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); fn as_read_mut(&mut self) -> &mut dyn AsyncRead;
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); fn as_write(&mut self) -> &mut dyn AsyncWrite;
} }
struct Socket<T: AsyncRead + AsyncWrite + Unpin>(T); struct Socket<T: AsyncRead + AsyncWrite>(T);
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncSocket for Socket<T> { impl<T: AsyncRead + AsyncWrite> AsyncSocket for Socket<T> {
fn as_read(&self) -> &(dyn AsyncRead + Unpin) { fn as_read(&self) -> &dyn AsyncRead {
&self.0 &self.0
} }
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { fn as_read_mut(&mut self) -> &mut dyn AsyncRead {
&mut self.0 &mut self.0
} }
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { fn as_write(&mut self) -> &mut dyn AsyncWrite {
&mut self.0 &mut self.0
} }
} }
@@ -200,37 +201,30 @@ impl fmt::Debug for BoxedSocket {
} }
} }
impl io::Read for BoxedSocket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.as_read_mut().read(buf)
}
}
impl AsyncRead for BoxedSocket { impl AsyncRead for BoxedSocket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.as_read().prepare_uninitialized_buffer(buf) self.0.as_read().prepare_uninitialized_buffer(buf)
} }
}
fn poll_read( impl io::Write for BoxedSocket {
self: Pin<&mut Self>, fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
cx: &mut Context<'_>, self.0.as_write().write(buf)
buf: &mut [u8], }
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) fn flush(&mut self) -> io::Result<()> {
self.0.as_write().flush()
} }
} }
impl AsyncWrite for BoxedSocket { impl AsyncWrite for BoxedSocket {
fn poll_write( fn shutdown(&mut self) -> Poll<(), io::Error> {
self: Pin<&mut Self>, self.0.as_write().shutdown()
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx)
} }
} }

View File

@@ -82,7 +82,7 @@ impl FrozenClientRequest {
/// Send a streaming body. /// Send a streaming body.
pub fn send_stream<S, E>(&self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(&self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
RequestSender::Rc(self.head.clone(), None).send_stream( RequestSender::Rc(self.head.clone(), None).send_stream(
@@ -203,7 +203,7 @@ impl FrozenSendBuilder {
/// Complete request construction and send a streaming body. /// Complete request construction and send a streaming body.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
if let Some(e) = self.err { if let Some(e) = self.err {

View File

@@ -7,18 +7,18 @@
//! use awc::Client; //! use awc::Client;
//! //!
//! fn main() { //! fn main() {
//! System::new("test").block_on(async { //! System::new("test").block_on(lazy(|| {
//! let mut client = Client::default(); //! let mut client = Client::default();
//! //!
//! client.get("http://www.rust-lang.org") // <- Create request builder //! client.get("http://www.rust-lang.org") // <- Create request builder
//! .header("User-Agent", "Actix-web") //! .header("User-Agent", "Actix-web")
//! .send() // <- Send http request //! .send() // <- Send http request
//! .await //! .map_err(|_| ())
//! .and_then(|response| { // <- server http response //! .and_then(|response| { // <- server http response
//! println!("Response: {:?}", response); //! println!("Response: {:?}", response);
//! Ok(()) //! Ok(())
//! }) //! })
//! }); //! }));
//! } //! }
//! ``` //! ```
use std::cell::RefCell; use std::cell::RefCell;
@@ -52,22 +52,23 @@ use self::connect::{Connect, ConnectorWrapper};
/// An HTTP Client /// An HTTP Client
/// ///
/// ```rust /// ```rust
/// # use futures::future::{Future, lazy};
/// use actix_rt::System; /// use actix_rt::System;
/// use awc::Client; /// use awc::Client;
/// ///
/// fn main() { /// fn main() {
/// System::new("test").block_on(async { /// System::new("test").block_on(lazy(|| {
/// let mut client = Client::default(); /// let mut client = Client::default();
/// ///
/// client.get("http://www.rust-lang.org") // <- Create request builder /// client.get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web") /// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request /// .send() // <- Send http request
/// .await /// .map_err(|_| ())
/// .and_then(|response| { // <- server http response /// .and_then(|response| { // <- server http response
/// println!("Response: {:?}", response); /// println!("Response: {:?}", response);
/// Ok(()) /// Ok(())
/// }) /// })
/// }); /// }));
/// } /// }
/// ``` /// ```
#[derive(Clone)] #[derive(Clone)]

View File

@@ -37,21 +37,21 @@ const HTTPS_ENCODING: &str = "gzip, deflate";
/// builder-like pattern. /// builder-like pattern.
/// ///
/// ```rust /// ```rust
/// use futures::future::{Future, lazy};
/// use actix_rt::System; /// use actix_rt::System;
/// ///
/// fn main() { /// fn main() {
/// System::new("test").block_on(async { /// System::new("test").block_on(lazy(|| {
/// let response = awc::Client::new() /// awc::Client::new()
/// .get("http://www.rust-lang.org") // <- Create request builder /// .get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web") /// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request /// .send() // <- Send http request
/// .await; /// .map_err(|_| ())
/// /// .and_then(|response| { // <- server http response
/// response.and_then(|response| { // <- server http response /// println!("Response: {:?}", response);
/// println!("Response: {:?}", response); /// Ok(())
/// Ok(())
/// }) /// })
/// }); /// }));
/// } /// }
/// ``` /// ```
pub struct ClientRequest { pub struct ClientRequest {
@@ -158,7 +158,7 @@ impl ClientRequest {
/// ///
/// ```rust /// ```rust
/// fn main() { /// fn main() {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| { /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| {
/// let req = awc::Client::new() /// let req = awc::Client::new()
/// .get("http://www.rust-lang.org") /// .get("http://www.rust-lang.org")
/// .set(awc::http::header::Date::now()) /// .set(awc::http::header::Date::now())
@@ -186,13 +186,13 @@ impl ClientRequest {
/// use awc::{http, Client}; /// use awc::{http, Client};
/// ///
/// fn main() { /// fn main() {
/// # actix_rt::System::new("test").block_on(async { /// # actix_rt::System::new("test").block_on(futures::future::lazy(|| {
/// let req = Client::new() /// let req = Client::new()
/// .get("http://www.rust-lang.org") /// .get("http://www.rust-lang.org")
/// .header("X-TEST", "value") /// .header("X-TEST", "value")
/// .header(http::header::CONTENT_TYPE, "application/json"); /// .header(http::header::CONTENT_TYPE, "application/json");
/// # Ok::<_, ()>(()) /// # Ok::<_, ()>(())
/// # }); /// # }));
/// } /// }
/// ``` /// ```
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
@@ -309,8 +309,9 @@ impl ClientRequest {
/// ///
/// ```rust /// ```rust
/// # use actix_rt::System; /// # use actix_rt::System;
/// # use futures::future::{lazy, Future};
/// fn main() { /// fn main() {
/// System::new("test").block_on(async { /// System::new("test").block_on(lazy(|| {
/// awc::Client::new().get("https://www.rust-lang.org") /// awc::Client::new().get("https://www.rust-lang.org")
/// .cookie( /// .cookie(
/// awc::http::Cookie::build("name", "value") /// awc::http::Cookie::build("name", "value")
@@ -321,12 +322,12 @@ impl ClientRequest {
/// .finish(), /// .finish(),
/// ) /// )
/// .send() /// .send()
/// .await /// .map_err(|_| ())
/// .and_then(|response| { /// .and_then(|response| {
/// println!("Response: {:?}", response); /// println!("Response: {:?}", response);
/// Ok(()) /// Ok(())
/// }) /// })
/// }); /// }));
/// } /// }
/// ``` /// ```
pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
@@ -477,7 +478,7 @@ impl ClientRequest {
/// Set an streaming body and generate `ClientRequest`. /// Set an streaming body and generate `ClientRequest`.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
let slf = match self.prep_for_sending() { let slf = match self.prep_for_sending() {

View File

@@ -1,11 +1,9 @@
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{ready, Future, Stream}; use futures::{Async, Future, Poll, Stream};
use actix_http::cookie::Cookie; use actix_http::cookie::Cookie;
use actix_http::error::{CookieParseError, PayloadError}; use actix_http::error::{CookieParseError, PayloadError};
@@ -106,7 +104,7 @@ impl<S> ClientResponse<S> {
impl<S> ClientResponse<S> impl<S> ClientResponse<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>>, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
/// Loads http response's body. /// Loads http response's body.
pub fn body(&mut self) -> MessageBody<S> { pub fn body(&mut self) -> MessageBody<S> {
@@ -127,12 +125,13 @@ where
impl<S> Stream for ClientResponse<S> impl<S> Stream for ClientResponse<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
Pin::new(&mut self.get_mut().payload).poll_next(cx) self.payload.poll()
} }
} }
@@ -156,7 +155,7 @@ pub struct MessageBody<S> {
impl<S> MessageBody<S> impl<S> MessageBody<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>>, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
/// Create `MessageBody` for request. /// Create `MessageBody` for request.
pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> { pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> {
@@ -199,24 +198,23 @@ where
impl<S> Future for MessageBody<S> impl<S> Future for MessageBody<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Output = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.get_mut(); if let Some(err) = self.err.take() {
return Err(err);
if let Some(err) = this.err.take() {
return Poll::Ready(Err(err));
} }
if let Some(len) = this.length.take() { if let Some(len) = self.length.take() {
if len > this.fut.as_ref().unwrap().limit { if len > self.fut.as_ref().unwrap().limit {
return Poll::Ready(Err(PayloadError::Overflow)); return Err(PayloadError::Overflow);
} }
} }
Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) self.fut.as_mut().unwrap().poll()
} }
} }
@@ -235,7 +233,7 @@ pub struct JsonBody<S, U> {
impl<S, U> JsonBody<S, U> impl<S, U> JsonBody<S, U>
where where
S: Stream<Item = Result<Bytes, PayloadError>>, S: Stream<Item = Bytes, Error = PayloadError>,
U: DeserializeOwned, U: DeserializeOwned,
{ {
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
@@ -281,35 +279,27 @@ where
} }
} }
impl<T, U> Unpin for JsonBody<T, U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
}
impl<T, U> Future for JsonBody<T, U> impl<T, U> Future for JsonBody<T, U>
where where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin, T: Stream<Item = Bytes, Error = PayloadError>,
U: DeserializeOwned, U: DeserializeOwned,
{ {
type Output = Result<U, JsonPayloadError>; type Item = U;
type Error = JsonPayloadError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<U, JsonPayloadError> {
if let Some(err) = self.err.take() { if let Some(err) = self.err.take() {
return Poll::Ready(Err(err)); return Err(err);
} }
if let Some(len) = self.length.take() { if let Some(len) = self.length.take() {
if len > self.fut.as_ref().unwrap().limit { if len > self.fut.as_ref().unwrap().limit {
return Poll::Ready(Err(JsonPayloadError::Payload( return Err(JsonPayloadError::Payload(PayloadError::Overflow));
PayloadError::Overflow,
)));
} }
} }
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; let body = futures::try_ready!(self.fut.as_mut().unwrap().poll());
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from)) Ok(Async::Ready(serde_json::from_slice::<U>(&body)?))
} }
} }
@@ -331,25 +321,24 @@ impl<S> ReadBody<S> {
impl<S> Future for ReadBody<S> impl<S> Future for ReadBody<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Output = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
return match Pin::new(&mut this.stream).poll_next(cx)? { return match self.stream.poll()? {
Poll::Ready(Some(chunk)) => { Async::Ready(Some(chunk)) => {
if (this.buf.len() + chunk.len()) > this.limit { if (self.buf.len() + chunk.len()) > self.limit {
Poll::Ready(Err(PayloadError::Overflow)) Err(PayloadError::Overflow)
} else { } else {
this.buf.extend_from_slice(&chunk); self.buf.extend_from_slice(&chunk);
continue; continue;
} }
} }
Poll::Ready(None) => Poll::Ready(Ok(this.buf.take().freeze())), Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())),
Poll::Pending => Poll::Pending, Async::NotReady => Ok(Async::NotReady),
}; };
} }
} }
@@ -359,40 +348,41 @@ where
mod tests { mod tests {
use super::*; use super::*;
use actix_http_test::block_on; use actix_http_test::block_on;
use futures::Async;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{http::header, test::TestResponse}; use crate::{http::header, test::TestResponse};
#[test] #[test]
fn test_body() { fn test_body() {
block_on(async { let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
let mut req = match req.body().poll().err().unwrap() {
TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); PayloadError::UnknownLength => (),
match req.body().await.err().unwrap() { _ => unreachable!("error"),
PayloadError::UnknownLength => (), }
_ => unreachable!("error"),
}
let mut req = let mut req =
TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish();
match req.body().await.err().unwrap() { match req.body().poll().err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let mut req = TestResponse::default() let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"test")) .set_payload(Bytes::from_static(b"test"))
.finish(); .finish();
assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); match req.body().poll().ok().unwrap() {
Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")),
_ => unreachable!("error"),
}
let mut req = TestResponse::default() let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"11111111111111")) .set_payload(Bytes::from_static(b"11111111111111"))
.finish(); .finish();
match req.body().limit(5).await.err().unwrap() { match req.body().limit(5).poll().err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
})
} }
#[derive(Serialize, Deserialize, PartialEq, Debug)] #[derive(Serialize, Deserialize, PartialEq, Debug)]
@@ -416,56 +406,54 @@ mod tests {
#[test] #[test]
fn test_json_body() { fn test_json_body() {
block_on(async { let mut req = TestResponse::default().finish();
let mut req = TestResponse::default().finish(); let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
let json = JsonBody::<_, MyObject>::new(&mut req).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"), header::HeaderValue::from_static("application/text"),
) )
.finish(); .finish();
let json = JsonBody::<_, MyObject>::new(&mut req).await; let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
) )
.header( .header(
header::CONTENT_LENGTH, header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"), header::HeaderValue::from_static("10000"),
) )
.finish(); .finish();
let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100));
assert!(json_eq( assert!(json_eq(
json.err().unwrap(), json.err().unwrap(),
JsonPayloadError::Payload(PayloadError::Overflow) JsonPayloadError::Payload(PayloadError::Overflow)
)); ));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
) )
.header( .header(
header::CONTENT_LENGTH, header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"), header::HeaderValue::from_static("16"),
) )
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.finish(); .finish();
let json = JsonBody::<_, MyObject>::new(&mut req).await; let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
assert_eq!( assert_eq!(
json.ok().unwrap(), json.ok().unwrap(),
MyObject { MyObject {
name: "test".to_owned() name: "test".to_owned()
} }
); );
})
} }
} }

View File

@@ -1,15 +1,13 @@
use std::net; use std::net;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll}; use std::time::{Duration, Instant};
use std::time::Duration;
use bytes::Bytes; use bytes::Bytes;
use derive_more::From; use derive_more::From;
use futures::{future::LocalBoxFuture, ready, Future, Stream}; use futures::{try_ready, Async, Future, Poll, Stream};
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
use tokio_timer::{delay_for, Delay}; use tokio_timer::Delay;
use actix_http::body::{Body, BodyStream}; use actix_http::body::{Body, BodyStream};
use actix_http::encoding::Decoder; use actix_http::encoding::Decoder;
@@ -49,7 +47,7 @@ impl Into<SendRequestError> for PrepForSendingError {
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
pub enum SendClientRequest { pub enum SendClientRequest {
Fut( Fut(
LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>, Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>,
Option<Delay>, Option<Delay>,
bool, bool,
), ),
@@ -58,51 +56,41 @@ pub enum SendClientRequest {
impl SendClientRequest { impl SendClientRequest {
pub(crate) fn new( pub(crate) fn new(
send: LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>, send: Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>,
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> SendClientRequest { ) -> SendClientRequest {
let delay = timeout.map(|t| delay_for(t)); let delay = timeout.map(|t| Delay::new(Instant::now() + t));
SendClientRequest::Fut(send, delay, response_decompress) SendClientRequest::Fut(send, delay, response_decompress)
} }
} }
impl Future for SendClientRequest { impl Future for SendClientRequest {
type Output = type Item = ClientResponse<Decoder<Payload<PayloadStream>>>;
Result<ClientResponse<Decoder<Payload<PayloadStream>>>, SendRequestError>; type Error = SendRequestError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.get_mut(); match self {
match this {
SendClientRequest::Fut(send, delay, response_decompress) => { SendClientRequest::Fut(send, delay, response_decompress) => {
if delay.is_some() { if delay.is_some() {
match Pin::new(delay.as_mut().unwrap()).poll(cx) { match delay.poll() {
Poll::Pending => (), Ok(Async::NotReady) => (),
_ => return Poll::Ready(Err(SendRequestError::Timeout)), _ => return Err(SendRequestError::Timeout),
} }
} }
let res = ready!(Pin::new(send).poll(cx)).map(|res| { let res = try_ready!(send.poll()).map_body(|head, payload| {
res.map_body(|head, payload| { if *response_decompress {
if *response_decompress { Payload::Stream(Decoder::from_headers(payload, &head.headers))
Payload::Stream(Decoder::from_headers( } else {
payload, Payload::Stream(Decoder::new(payload, ContentEncoding::Identity))
&head.headers, }
))
} else {
Payload::Stream(Decoder::new(
payload,
ContentEncoding::Identity,
))
}
})
}); });
Poll::Ready(res) Ok(Async::Ready(res))
} }
SendClientRequest::Err(ref mut e) => match e.take() { SendClientRequest::Err(ref mut e) => match e.take() {
Some(e) => Poll::Ready(Err(e)), Some(e) => Err(e),
None => panic!("Attempting to call completed future"), None => panic!("Attempting to call completed future"),
}, },
} }
@@ -235,7 +223,7 @@ impl RequestSender {
stream: S, stream: S,
) -> SendClientRequest ) -> SendClientRequest
where where
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
self.send_body( self.send_body(

View File

@@ -7,6 +7,7 @@ use std::{fmt, str};
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::cookie::{Cookie, CookieJar}; use actix_http::cookie::{Cookie, CookieJar};
use actix_http::{ws, Payload, RequestHead}; use actix_http::{ws, Payload, RequestHead};
use futures::future::{err, Either, Future};
use percent_encoding::percent_encode; use percent_encoding::percent_encode;
use tokio_timer::Timeout; use tokio_timer::Timeout;
@@ -209,26 +210,27 @@ impl WebsocketsRequest {
} }
/// Complete request construction and connect to a websockets server. /// Complete request construction and connect to a websockets server.
pub async fn connect( pub fn connect(
mut self, mut self,
) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> { ) -> impl Future<Item = (ClientResponse, Framed<BoxedSocket, Codec>), Error = WsClientError>
{
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
return Err(e.into()); return Either::A(err(e.into()));
} }
// validate uri // validate uri
let uri = &self.head.uri; let uri = &self.head.uri;
if uri.host().is_none() { if uri.host().is_none() {
return Err(InvalidUrl::MissingHost.into()); return Either::A(err(InvalidUrl::MissingHost.into()));
} else if uri.scheme_part().is_none() { } else if uri.scheme_part().is_none() {
return Err(InvalidUrl::MissingScheme.into()); return Either::A(err(InvalidUrl::MissingScheme.into()));
} else if let Some(scheme) = uri.scheme_part() { } else if let Some(scheme) = uri.scheme_part() {
match scheme.as_str() { match scheme.as_str() {
"http" | "ws" | "https" | "wss" => (), "http" | "ws" | "https" | "wss" => (),
_ => return Err(InvalidUrl::UnknownScheme.into()), _ => return Either::A(err(InvalidUrl::UnknownScheme.into())),
} }
} else { } else {
return Err(InvalidUrl::UnknownScheme.into()); return Either::A(err(InvalidUrl::UnknownScheme.into()));
} }
if !self.head.headers.contains_key(header::HOST) { if !self.head.headers.contains_key(header::HOST) {
@@ -292,83 +294,90 @@ impl WebsocketsRequest {
.config .config
.connector .connector
.borrow_mut() .borrow_mut()
.open_tunnel(head, self.addr); .open_tunnel(head, self.addr)
.from_err()
.and_then(move |(head, framed)| {
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
}
} else {
log::trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader);
}
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
let encoded = ws::hash_key(key.as_ref());
if hdr_key.as_bytes() != encoded.as_bytes() {
log::trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded,
key
);
return Err(WsClientError::InvalidChallengeResponse(
encoded,
hdr_key.clone(),
));
}
} else {
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
}
}),
))
});
// set request timeout // set request timeout
let (head, framed) = if let Some(timeout) = self.config.timeout { if let Some(timeout) = self.config.timeout {
Timeout::new(fut, timeout) Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| {
.await if let Some(e) = e.into_inner() {
.map_err(|_| SendRequestError::Timeout.into()) e
.and_then(|res| res)?
} else {
fut.await?
};
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader);
}
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
let encoded = ws::hash_key(key.as_ref());
if hdr_key.as_bytes() != encoded.as_bytes() {
log::trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded,
key
);
return Err(WsClientError::InvalidChallengeResponse(
encoded,
hdr_key.clone(),
));
}
} else {
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else { } else {
ws::Codec::new().max_size(max_size).client_mode() SendRequestError::Timeout.into()
} }
}), })))
)) } else {
Either::B(Either::B(fut))
}
} }
} }
@@ -389,8 +398,6 @@ impl fmt::Debug for WebsocketsRequest {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_web::test::block_on;
use super::*; use super::*;
use crate::Client; use crate::Client;
@@ -465,33 +472,35 @@ mod tests {
#[test] #[test]
fn basics() { fn basics() {
block_on(async { let req = Client::new()
let req = Client::new() .ws("http://localhost/")
.ws("http://localhost/") .origin("test-origin")
.origin("test-origin") .max_frame_size(100)
.max_frame_size(100) .server_mode()
.server_mode() .protocols(&["v1", "v2"])
.protocols(&["v1", "v2"]) .set_header_if_none(header::CONTENT_TYPE, "json")
.set_header_if_none(header::CONTENT_TYPE, "json") .set_header_if_none(header::CONTENT_TYPE, "text")
.set_header_if_none(header::CONTENT_TYPE, "text") .cookie(Cookie::build("cookie1", "value1").finish());
.cookie(Cookie::build("cookie1", "value1").finish()); assert_eq!(
assert_eq!( req.origin.as_ref().unwrap().to_str().unwrap(),
req.origin.as_ref().unwrap().to_str().unwrap(), "test-origin"
"test-origin" );
); assert_eq!(req.max_size, 100);
assert_eq!(req.max_size, 100); assert_eq!(req.server_mode, true);
assert_eq!(req.server_mode, true); assert_eq!(req.protocols, Some("v1,v2".to_string()));
assert_eq!(req.protocols, Some("v1,v2".to_string())); assert_eq!(
assert_eq!( req.head.headers.get(header::CONTENT_TYPE).unwrap(),
req.head.headers.get(header::CONTENT_TYPE).unwrap(), header::HeaderValue::from_static("json")
header::HeaderValue::from_static("json") );
);
let _ = req.connect().await; let _ = actix_http_test::block_fn(move || req.connect());
assert!(Client::new().ws("/").connect().await.is_err()); assert!(Client::new().ws("/").connect().poll().is_err());
assert!(Client::new().ws("http:///test").connect().await.is_err()); assert!(Client::new().ws("http:///test").connect().poll().is_err());
assert!(Client::new().ws("hmm://test.com/").connect().await.is_err()); assert!(Client::new()
}) .ws("hmm://test.com/")
.connect()
.poll()
.is_err());
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,109 +1,96 @@
#![cfg(feature = "rustls")] #![cfg(feature = "rust-tls")]
use rust_tls::ClientConfig; use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
ClientConfig, NoClientAuth,
};
use std::io::Result; use std::fs::File;
use std::io::{BufReader, Result};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor; use actix_server::ssl::RustlsAcceptor;
use actix_service::{pipeline_factory, ServiceFactory}; use actix_service::{service_fn, NewService};
use actix_web::http::Version; use actix_web::http::Version;
use actix_web::{web, App, HttpResponse}; use actix_web::{web, App, HttpResponse};
use futures::future::ok;
use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> { fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
use rustls::ServerConfig;
// load ssl keys // load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let mut config = ServerConfig::new(NoClientAuth::new());
builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true); let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
builder let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
.set_private_key_file("../tests/key.pem", SslFiletype::PEM) let cert_chain = certs(cert_file).unwrap();
.unwrap(); let mut keys = pkcs8_private_keys(key_file).unwrap();
builder config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
.set_certificate_chain_file("../tests/cert.pem") let protos = vec![b"h2".to_vec()];
.unwrap(); config.set_protocols(&protos);
builder.set_alpn_select_callback(|_, protos| { Ok(RustlsAcceptor::new(config))
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(open_ssl::ssl::AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(actix_server::ssl::OpensslAcceptor::new(builder.build()))
} }
mod danger { mod danger {
pub struct NoCertificateVerification {} pub struct NoCertificateVerification {}
impl rust_tls::ServerCertVerifier for NoCertificateVerification { impl rustls::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert( fn verify_server_cert(
&self, &self,
_roots: &rust_tls::RootCertStore, _roots: &rustls::RootCertStore,
_presented_certs: &[rust_tls::Certificate], _presented_certs: &[rustls::Certificate],
_dns_name: webpki::DNSNameRef<'_>, _dns_name: webpki::DNSNameRef<'_>,
_ocsp: &[u8], _ocsp: &[u8],
) -> Result<rust_tls::ServerCertVerified, rust_tls::TLSError> { ) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
Ok(rust_tls::ServerCertVerified::assertion()) Ok(rustls::ServerCertVerified::assertion())
} }
} }
} }
// #[test] #[test]
fn _test_connection_reuse_h2() { fn test_connection_reuse_h2() {
block_on(async { let rustls = ssl_acceptor().unwrap();
let openssl = ssl_acceptor().unwrap(); let num = Arc::new(AtomicUsize::new(0));
let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone();
let num2 = num.clone();
let srv = TestServer::start(move || { let mut srv = TestServer::new(move || {
let num2 = num2.clone(); let num2 = num2.clone();
pipeline_factory(move |io| { service_fn(move |io| {
num2.fetch_add(1, Ordering::Relaxed); num2.fetch_add(1, Ordering::Relaxed);
ok(io) Ok(io)
}) })
.and_then( .and_then(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
openssl .and_then(
.clone() HttpService::build()
.map_err(|e| println!("Openssl error: {}", e)), .h2(App::new()
) .service(web::resource("/").route(web::to(|| HttpResponse::Ok()))))
.and_then( .map_err(|_| ()),
HttpService::build() )
.h2(App::new().service( });
web::resource("/").route(web::to(|| HttpResponse::Ok())),
))
.map_err(|_| ()),
)
});
// disable ssl verification // disable ssl verification
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
config.set_protocols(&protos); config.set_protocols(&protos);
config config
.dangerous() .dangerous()
.set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {}));
let client = awc::Client::build() let client = awc::Client::build()
.connector(awc::Connector::new().rustls(Arc::new(config)).finish()) .connector(awc::Connector::new().rustls(Arc::new(config)).finish())
.finish(); .finish();
// req 1 // req 1
let request = client.get(srv.surl("/")).send(); let request = client.get(srv.surl("/")).send();
let response = request.await.unwrap(); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// req 2 // req 2
let req = client.post(srv.surl("/")); let req = client.post(srv.surl("/"));
let response = req.send().await.unwrap(); let response = srv.block_on_fn(move || req.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2); assert_eq!(response.version(), Version::HTTP_2);
// one connection // one connection
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
})
} }

View File

@@ -1,5 +1,5 @@
#![cfg(feature = "openssl")] #![cfg(feature = "ssl")]
use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use std::io::Result; use std::io::Result;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@@ -7,12 +7,11 @@ use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor; use actix_server::ssl::OpensslAcceptor;
use actix_service::{pipeline_factory, ServiceFactory}; use actix_service::{service_fn, NewService};
use actix_web::http::Version; use actix_web::http::Version;
use actix_web::{web, App, HttpResponse}; use actix_web::{web, App, HttpResponse};
use futures::future::ok;
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> { fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys // load ssl keys
@@ -28,7 +27,7 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
if protos.windows(3).any(|window| window == H2) { if protos.windows(3).any(|window| window == H2) {
Ok(b"h2") Ok(b"h2")
} else { } else {
Err(open_ssl::ssl::AlpnError::NOACK) Err(openssl::ssl::AlpnError::NOACK)
} }
}); });
builder.set_alpn_protos(b"\x02h2")?; builder.set_alpn_protos(b"\x02h2")?;
@@ -37,54 +36,51 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
#[test] #[test]
fn test_connection_reuse_h2() { fn test_connection_reuse_h2() {
block_on(async { let openssl = ssl_acceptor().unwrap();
let openssl = ssl_acceptor().unwrap(); let num = Arc::new(AtomicUsize::new(0));
let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone();
let num2 = num.clone();
let srv = TestServer::start(move || { let mut srv = TestServer::new(move || {
let num2 = num2.clone(); let num2 = num2.clone();
pipeline_factory(move |io| { service_fn(move |io| {
num2.fetch_add(1, Ordering::Relaxed); num2.fetch_add(1, Ordering::Relaxed);
ok(io) Ok(io)
}) })
.and_then( .and_then(
openssl openssl
.clone() .clone()
.map_err(|e| println!("Openssl error: {}", e)), .map_err(|e| println!("Openssl error: {}", e)),
) )
.and_then( .and_then(
HttpService::build() HttpService::build()
.h2(App::new().service( .h2(App::new()
web::resource("/").route(web::to(|| HttpResponse::Ok())), .service(web::resource("/").route(web::to(|| HttpResponse::Ok()))))
)) .map_err(|_| ()),
.map_err(|_| ()), )
) });
});
// disable ssl verification // disable ssl verification
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE); builder.set_verify(SslVerifyMode::NONE);
let _ = builder let _ = builder
.set_alpn_protos(b"\x02h2\x08http/1.1") .set_alpn_protos(b"\x02h2\x08http/1.1")
.map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e));
let client = awc::Client::build() let client = awc::Client::build()
.connector(awc::Connector::new().ssl(builder.build()).finish()) .connector(awc::Connector::new().ssl(builder.build()).finish())
.finish(); .finish();
// req 1 // req 1
let request = client.get(srv.surl("/")).send(); let request = client.get(srv.surl("/")).send();
let response = request.await.unwrap(); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// req 2 // req 2
let req = client.post(srv.surl("/")); let req = client.post(srv.surl("/"));
let response = req.send().await.unwrap(); let response = srv.block_on_fn(move || req.send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2); assert_eq!(response.version(), Version::HTTP_2);
// one connection // one connection
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
})
} }

View File

@@ -2,82 +2,81 @@ use std::io;
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer}; use actix_http_test::TestServer;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::future::ok; use futures::future::ok;
use futures::{SinkExt, StreamExt}; use futures::{Future, Sink, Stream};
async fn ws_service(req: ws::Frame) -> Result<ws::Message, io::Error> { fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> {
match req { match req {
ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)), ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
let text = if let Some(pl) = text { let text = if let Some(pl) = text {
String::from_utf8(Vec::from(pl.as_ref())).unwrap() String::from_utf8(Vec::from(pl.as_ref())).unwrap()
} else { } else {
String::new() String::new()
}; };
Ok(ws::Message::Text(text)) ok(ws::Message::Text(text))
} }
ws::Frame::Binary(bin) => Ok(ws::Message::Binary( ws::Frame::Binary(bin) => ok(ws::Message::Binary(
bin.map(|e| e.freeze()) bin.map(|e| e.freeze())
.unwrap_or_else(|| Bytes::from("")) .unwrap_or_else(|| Bytes::from(""))
.into(), .into(),
)), )),
ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)), ws::Frame::Close(reason) => ok(ws::Message::Close(reason)),
_ => Ok(ws::Message::Close(None)), _ => ok(ws::Message::Close(None)),
} }
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
block_on(async { let mut srv = TestServer::new(|| {
let mut srv = TestServer::start(|| { HttpService::build()
HttpService::build() .upgrade(|(req, framed): (Request, Framed<_, _>)| {
.upgrade(|(req, mut framed): (Request, Framed<_, _>)| { let res = ws::handshake_response(req.head()).finish();
async move { // send handshake response
let res = ws::handshake_response(req.head()).finish(); framed
// send handshake response .send(h1::Message::Item((res.drop_body(), BodySize::None)))
framed .map_err(|e: io::Error| e.into())
.send(h1::Message::Item((res.drop_body(), BodySize::None))) .and_then(|framed| {
.await?;
// start websocket service // start websocket service
let framed = framed.into_framed(ws::Codec::new()); let framed = framed.into_framed(ws::Codec::new());
ws::Transport::with(framed, ws_service).await ws::Transport::with(framed, ws_service)
} })
}) })
.finish(|_| ok::<_, Error>(Response::NotFound())) .finish(|_| ok::<_, Error>(Response::NotFound()))
}); });
// client service // client service
let mut framed = srv.ws().await.unwrap(); let framed = srv.ws().unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
assert_eq!(item, ws::Frame::Text(Some(BytesMut::from("text"))));
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Binary("text".into())))
.await .unwrap();
.unwrap(); let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let item = framed.next().await.unwrap().unwrap(); assert_eq!(
assert_eq!( item,
item, Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) );
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let item = framed.next().await.unwrap().unwrap(); .block_on(framed.send(ws::Message::Ping("text".into())))
assert_eq!(item, ws::Frame::Pong("text".to_string().into())); .unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
framed let framed = srv
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.await .unwrap();
.unwrap();
let item = framed.next().await.unwrap().unwrap(); let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); assert_eq!(
}) item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
);
} }

View File

@@ -1,18 +1,22 @@
use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer}; use futures::IntoFuture;
use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
};
#[get("/resource1/{name}/index.html")] #[get("/resource1/{name}/index.html")]
async fn index(req: HttpRequest, name: web::Path<String>) -> String { fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name) format!("Hello: {}!\r\n", name)
} }
async fn index_async(req: HttpRequest) -> &'static str { fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
"Hello world!\r\n" Ok("Hello world!\r\n")
} }
#[get("/")] #[get("/")]
async fn no_params() -> &'static str { fn no_params() -> &'static str {
"Hello world!\r\n" "Hello world!\r\n"
} }
@@ -35,9 +39,9 @@ fn main() -> std::io::Result<()> {
.default_service( .default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()), web::route().to(|| HttpResponse::MethodNotAllowed()),
) )
.route(web::get().to(index_async)), .route(web::get().to_async(index_async)),
) )
.service(web::resource("/test1.html").to(|| async { "Test\r\n" })) .service(web::resource("/test1.html").to(|| "Test\r\n"))
}) })
.bind("127.0.0.1:8080")? .bind("127.0.0.1:8080")?
.workers(1) .workers(1)

View File

@@ -1,27 +1,26 @@
use actix_http::Error; use actix_http::Error;
use actix_rt::System; use actix_rt::System;
use futures::{future::lazy, Future};
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
std::env::set_var("RUST_LOG", "actix_http=trace"); std::env::set_var("RUST_LOG", "actix_http=trace");
env_logger::init(); env_logger::init();
System::new("test").block_on(async { System::new("test").block_on(lazy(|| {
let client = awc::Client::new(); awc::Client::new()
.get("https://www.rust-lang.org/") // <- Create request builder
// Create request builder, configure request and send
let mut response = client
.get("https://www.rust-lang.org/")
.header("User-Agent", "Actix-web") .header("User-Agent", "Actix-web")
.send() .send() // <- Send http request
.await?; .from_err()
.and_then(|mut response| {
// <- server http response
println!("Response: {:?}", response);
// server http response // read response body
println!("Response: {:?}", response); response
.body()
// read response body .from_err()
let body = response.body().await?; .map(|body| println!("Downloaded: {:?} bytes", body.len()))
println!("Downloaded: {:?} bytes", body.len()); })
}))
Ok(())
})
} }

View File

@@ -1,24 +1,26 @@
use futures::IntoFuture;
use actix_web::{ use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
}; };
#[get("/resource1/{name}/index.html")] #[get("/resource1/{name}/index.html")]
async fn index(req: HttpRequest, name: web::Path<String>) -> String { fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name) format!("Hello: {}!\r\n", name)
} }
async fn index_async(req: HttpRequest) -> Result<&'static str, Error> { fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
Ok("Hello world!\r\n") Ok("Hello world!\r\n")
} }
#[get("/")] #[get("/")]
async fn no_params() -> &'static str { fn no_params() -> &'static str {
"Hello world!\r\n" "Hello world!\r\n"
} }
#[cfg(unix)] #[cfg(feature = "uds")]
fn main() -> std::io::Result<()> { fn main() -> std::io::Result<()> {
std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info");
env_logger::init(); env_logger::init();
@@ -38,14 +40,14 @@ fn main() -> std::io::Result<()> {
.default_service( .default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()), web::route().to(|| HttpResponse::MethodNotAllowed()),
) )
.route(web::get().to(index_async)), .route(web::get().to_async(index_async)),
) )
.service(web::resource("/test1.html").to(|| async { "Test\r\n" })) .service(web::resource("/test1.html").to(|| "Test\r\n"))
}) })
.bind_uds("/Users/fafhrd91/uds-test")? .bind_uds("/Users/fafhrd91/uds-test")?
.workers(1) .workers(1)
.run() .run()
} }
#[cfg(not(unix))] #[cfg(not(feature = "uds"))]
fn main() {} fn main() {}

View File

@@ -1,17 +1,14 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::body::{Body, MessageBody}; use actix_http::body::{Body, MessageBody};
use actix_service::boxed::{self, BoxedNewService}; use actix_service::boxed::{self, BoxedNewService};
use actix_service::{ use actix_service::{
apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, apply_transform, IntoNewService, IntoTransform, NewService, Transform,
}; };
use futures::future::{FutureExt, LocalBoxFuture}; use futures::{Future, IntoFuture};
use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; use crate::app_service::{AppEntry, AppInit, AppRoutingFactory};
use crate::config::{AppConfig, AppConfigInner, ServiceConfig}; use crate::config::{AppConfig, AppConfigInner, ServiceConfig};
@@ -21,19 +18,19 @@ use crate::error::Error;
use crate::resource::Resource; use crate::resource::Resource;
use crate::route::Route; use crate::route::Route;
use crate::service::{ use crate::service::{
AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse, ServiceResponse,
}; };
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type FnDataFactory = type FnDataFactory =
Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>; Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>;
/// Application builder - structure that follows the builder pattern /// Application builder - structure that follows the builder pattern
/// for building application instances. /// for building application instances.
pub struct App<T, B> { pub struct App<T, B> {
endpoint: T, endpoint: T,
services: Vec<Box<dyn AppServiceFactory>>, services: Vec<Box<dyn ServiceFactory>>,
default: Option<Rc<HttpNewService>>, default: Option<Rc<HttpNewService>>,
factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>, factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
data: Vec<Box<dyn DataFactory>>, data: Vec<Box<dyn DataFactory>>,
@@ -64,7 +61,7 @@ impl App<AppEntry, Body> {
impl<T, B> App<T, B> impl<T, B> App<T, B>
where where
B: MessageBody, B: MessageBody,
T: ServiceFactory< T: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -90,7 +87,7 @@ where
/// counter: Cell<usize>, /// counter: Cell<usize>,
/// } /// }
/// ///
/// async fn index(data: web::Data<MyData>) { /// fn index(data: web::Data<MyData>) {
/// data.counter.set(data.counter.get() + 1); /// data.counter.set(data.counter.get() + 1);
/// } /// }
/// ///
@@ -110,30 +107,24 @@ where
/// Set application data factory. This function is /// Set application data factory. This function is
/// similar to `.data()` but it accepts data factory. Data object get /// similar to `.data()` but it accepts data factory. Data object get
/// constructed asynchronously during application initialization. /// constructed asynchronously during application initialization.
pub fn data_factory<F, Out, D, E>(mut self, data: F) -> Self pub fn data_factory<F, Out>(mut self, data: F) -> Self
where where
F: Fn() -> Out + 'static, F: Fn() -> Out + 'static,
Out: Future<Output = Result<D, E>> + 'static, Out: IntoFuture + 'static,
D: 'static, Out::Error: std::fmt::Debug,
E: std::fmt::Debug,
{ {
self.data_factories.push(Box::new(move || { self.data_factories.push(Box::new(move || {
{ Box::new(
let fut = data(); data()
async move { .into_future()
match fut.await { .map_err(|e| {
Err(e) => { log::error!("Can not construct data instance: {:?}", e);
log::error!("Can not construct data instance: {:?}", e); })
Err(()) .map(|data| {
} let data: Box<dyn DataFactory> = Box::new(Data::new(data));
Ok(data) => { data
let data: Box<dyn DataFactory> = Box::new(Data::new(data)); }),
Ok(data) )
}
}
}
}
.boxed_local()
})); }));
self self
} }
@@ -192,7 +183,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpResponse}; /// use actix_web::{web, App, HttpResponse};
/// ///
/// async fn index(data: web::Path<(String, String)>) -> &'static str { /// fn index(data: web::Path<(String, String)>) -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -247,7 +238,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpResponse}; /// use actix_web::{web, App, HttpResponse};
/// ///
/// async fn index() -> &'static str { /// fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -276,8 +267,8 @@ where
/// ``` /// ```
pub fn default_service<F, U>(mut self, f: F) -> Self pub fn default_service<F, U>(mut self, f: F) -> Self
where where
F: IntoServiceFactory<U>, F: IntoNewService<U>,
U: ServiceFactory< U: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
@@ -286,9 +277,11 @@ where
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
// create and configure default resource // create and configure default resource
self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( self.default = Some(Rc::new(boxed::new_service(
|e| log::error!("Can not construct default service: {:?}", e), f.into_new_service().map_init_err(|e| {
)))); log::error!("Can not construct default service: {:?}", e)
}),
)));
self self
} }
@@ -302,7 +295,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; /// use actix_web::{web, App, HttpRequest, HttpResponse, Result};
/// ///
/// async fn index(req: HttpRequest) -> Result<HttpResponse> { /// fn index(req: HttpRequest) -> Result<HttpResponse> {
/// let url = req.url_for("youtube", &["asdlkjqme"])?; /// let url = req.url_for("youtube", &["asdlkjqme"])?;
/// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme"); /// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme");
/// Ok(HttpResponse::Ok().into()) /// Ok(HttpResponse::Ok().into())
@@ -343,10 +336,11 @@ where
/// ///
/// ```rust /// ```rust
/// use actix_service::Service; /// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{middleware, web, App}; /// use actix_web::{middleware, web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
/// ///
/// async fn index() -> &'static str { /// fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -356,11 +350,11 @@ where
/// .route("/index.html", web::get().to(index)); /// .route("/index.html", web::get().to(index));
/// } /// }
/// ``` /// ```
pub fn wrap<M, B1>( pub fn wrap<M, B1, F>(
self, self,
mw: M, mw: F,
) -> App< ) -> App<
impl ServiceFactory< impl NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B1>, Response = ServiceResponse<B1>,
@@ -378,9 +372,11 @@ where
InitError = (), InitError = (),
>, >,
B1: MessageBody, B1: MessageBody,
F: IntoTransform<M, T::Service>,
{ {
let endpoint = apply_transform(mw, self.endpoint);
App { App {
endpoint: apply(mw, self.endpoint), endpoint,
data: self.data, data: self.data,
data_factories: self.data_factories, data_factories: self.data_factories,
services: self.services, services: self.services,
@@ -401,25 +397,23 @@ where
/// ///
/// ```rust /// ```rust
/// use actix_service::Service; /// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{web, App}; /// use actix_web::{web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
/// ///
/// async fn index() -> &'static str { /// fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
/// fn main() { /// fn main() {
/// let app = App::new() /// let app = App::new()
/// .wrap_fn(|req, srv| { /// .wrap_fn(|req, srv|
/// let fut = srv.call(req); /// srv.call(req).map(|mut res| {
/// async {
/// let mut res = fut.await?;
/// res.headers_mut().insert( /// res.headers_mut().insert(
/// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// CONTENT_TYPE, HeaderValue::from_static("text/plain"),
/// ); /// );
/// Ok(res) /// res
/// } /// }))
/// })
/// .route("/index.html", web::get().to(index)); /// .route("/index.html", web::get().to(index));
/// } /// }
/// ``` /// ```
@@ -427,7 +421,7 @@ where
self, self,
mw: F, mw: F,
) -> App< ) -> App<
impl ServiceFactory< impl NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B1>, Response = ServiceResponse<B1>,
@@ -439,26 +433,16 @@ where
where where
B1: MessageBody, B1: MessageBody,
F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone,
R: Future<Output = Result<ServiceResponse<B1>, Error>>, R: IntoFuture<Item = ServiceResponse<B1>, Error = Error>,
{ {
App { self.wrap(mw)
endpoint: apply_fn_factory(self.endpoint, mw),
data: self.data,
data_factories: self.data_factories,
services: self.services,
default: self.default,
factory_ref: self.factory_ref,
config: self.config,
external: self.external,
_t: PhantomData,
}
} }
} }
impl<T, B> IntoServiceFactory<AppInit<T, B>> for App<T, B> impl<T, B> IntoNewService<AppInit<T, B>> for App<T, B>
where where
B: MessageBody, B: MessageBody,
T: ServiceFactory< T: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -466,7 +450,7 @@ where
InitError = (), InitError = (),
>, >,
{ {
fn into_factory(self) -> AppInit<T, B> { fn into_new_service(self) -> AppInit<T, B> {
AppInit { AppInit {
data: Rc::new(self.data), data: Rc::new(self.data),
data_factories: Rc::new(self.data_factories), data_factories: Rc::new(self.data_factories),
@@ -484,89 +468,82 @@ where
mod tests { mod tests {
use actix_service::Service; use actix_service::Service;
use bytes::Bytes; use bytes::Bytes;
use futures::future::{ok, Future}; use futures::{Future, IntoFuture};
use super::*; use super::*;
use crate::http::{header, HeaderValue, Method, StatusCode}; use crate::http::{header, HeaderValue, Method, StatusCode};
use crate::middleware::DefaultHeaders;
use crate::service::{ServiceRequest, ServiceResponse}; use crate::service::{ServiceRequest, ServiceResponse};
use crate::test::{block_on, call_service, init_service, read_body, TestRequest}; use crate::test::{
block_fn, block_on, call_service, init_service, read_body, TestRequest,
};
use crate::{web, Error, HttpRequest, HttpResponse}; use crate::{web, Error, HttpRequest, HttpResponse};
#[test] #[test]
fn test_default_resource() { fn test_default_resource() {
block_on(async { let mut srv = init_service(
let mut srv = init_service( App::new().service(web::resource("/test").to(|| HttpResponse::Ok())),
App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), );
) let req = TestRequest::with_uri("/test").to_request();
.await; let resp = block_fn(|| srv.call(req)).unwrap();
let req = TestRequest::with_uri("/test").to_request(); assert_eq!(resp.status(), StatusCode::OK);
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/blah").to_request(); let req = TestRequest::with_uri("/blah").to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let mut srv = init_service( let mut srv = init_service(
App::new() App::new()
.service(web::resource("/test").to(|| HttpResponse::Ok())) .service(web::resource("/test").to(|| HttpResponse::Ok()))
.service( .service(
web::resource("/test2") web::resource("/test2")
.default_service(|r: ServiceRequest| { .default_service(|r: ServiceRequest| {
ok(r.into_response(HttpResponse::Created())) r.into_response(HttpResponse::Created())
}) })
.route(web::get().to(|| HttpResponse::Ok())), .route(web::get().to(|| HttpResponse::Ok())),
) )
.default_service(|r: ServiceRequest| { .default_service(|r: ServiceRequest| {
ok(r.into_response(HttpResponse::MethodNotAllowed())) r.into_response(HttpResponse::MethodNotAllowed())
}), }),
) );
.await;
let req = TestRequest::with_uri("/blah").to_request(); let req = TestRequest::with_uri("/blah").to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let req = TestRequest::with_uri("/test2").to_request(); let req = TestRequest::with_uri("/test2").to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/test2") let req = TestRequest::with_uri("/test2")
.method(Method::POST) .method(Method::POST)
.to_request(); .to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
})
} }
#[test] #[test]
fn test_data_factory() { fn test_data_factory() {
block_on(async { let mut srv =
let mut srv = init_service(App::new().data_factory(|| Ok::<_, ()>(10usize)).service(
init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service( web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), ));
)) let req = TestRequest::default().to_request();
.await; let resp = block_on(srv.call(req)).unwrap();
let req = TestRequest::default().to_request(); assert_eq!(resp.status(), StatusCode::OK);
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let mut srv = let mut srv =
init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service( init_service(App::new().data_factory(|| Ok::<_, ()>(10u32)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)) ));
.await; let req = TestRequest::default().to_request();
let req = TestRequest::default().to_request(); let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
fn md<S, B>( fn md<S, B>(
req: ServiceRequest, req: ServiceRequest,
srv: &mut S, srv: &mut S,
) -> impl Future<Output = Result<ServiceResponse<B>, Error>> ) -> impl IntoFuture<Item = ServiceResponse<B>, Error = Error>
where where
S: Service< S: Service<
Request = ServiceRequest, Request = ServiceRequest,
@@ -574,141 +551,112 @@ mod tests {
Error = Error, Error = Error,
>, >,
{ {
let fut = srv.call(req); srv.call(req).map(|mut res| {
async move {
let mut res = fut.await?;
res.headers_mut() res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(res) res
} })
} }
#[test] #[test]
fn test_wrap() { fn test_wrap() {
block_on(async { let mut srv = init_service(
let mut srv = App::new()
init_service( .wrap(md)
App::new() .route("/test", web::get().to(|| HttpResponse::Ok())),
.wrap(DefaultHeaders::new().header( );
header::CONTENT_TYPE, let req = TestRequest::with_uri("/test").to_request();
HeaderValue::from_static("0001"), let resp = call_service(&mut srv, req);
)) assert_eq!(resp.status(), StatusCode::OK);
.route("/test", web::get().to(|| HttpResponse::Ok())), assert_eq!(
) resp.headers().get(header::CONTENT_TYPE).unwrap(),
.await; HeaderValue::from_static("0001")
let req = TestRequest::with_uri("/test").to_request(); );
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_router_wrap() { fn test_router_wrap() {
block_on(async { let mut srv = init_service(
let mut srv = App::new()
init_service( .route("/test", web::get().to(|| HttpResponse::Ok()))
App::new() .wrap(md),
.route("/test", web::get().to(|| HttpResponse::Ok())) );
.wrap(DefaultHeaders::new().header( let req = TestRequest::with_uri("/test").to_request();
header::CONTENT_TYPE, let resp = call_service(&mut srv, req);
HeaderValue::from_static("0001"), assert_eq!(resp.status(), StatusCode::OK);
)), assert_eq!(
) resp.headers().get(header::CONTENT_TYPE).unwrap(),
.await; HeaderValue::from_static("0001")
let req = TestRequest::with_uri("/test").to_request(); );
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_wrap_fn() { fn test_wrap_fn() {
block_on(async { let mut srv = init_service(
let mut srv = init_service( App::new()
App::new() .wrap_fn(|req, srv| {
.wrap_fn(|req, srv| { srv.call(req).map(|mut res| {
let fut = srv.call(req); res.headers_mut().insert(
async move { header::CONTENT_TYPE,
let mut res = fut.await?; HeaderValue::from_static("0001"),
res.headers_mut().insert( );
header::CONTENT_TYPE, res
HeaderValue::from_static("0001"),
);
Ok(res)
}
}) })
.service(web::resource("/test").to(|| HttpResponse::Ok())), })
) .service(web::resource("/test").to(|| HttpResponse::Ok())),
.await; );
let req = TestRequest::with_uri("/test").to_request(); let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req).await; let resp = call_service(&mut srv, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!( assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(), resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001") HeaderValue::from_static("0001")
); );
})
} }
#[test] #[test]
fn test_router_wrap_fn() { fn test_router_wrap_fn() {
block_on(async { let mut srv = init_service(
let mut srv = init_service( App::new()
App::new() .route("/test", web::get().to(|| HttpResponse::Ok()))
.route("/test", web::get().to(|| HttpResponse::Ok())) .wrap_fn(|req, srv| {
.wrap_fn(|req, srv| { srv.call(req).map(|mut res| {
let fut = srv.call(req); res.headers_mut().insert(
async { header::CONTENT_TYPE,
let mut res = fut.await?; HeaderValue::from_static("0001"),
res.headers_mut().insert( );
header::CONTENT_TYPE, res
HeaderValue::from_static("0001"), })
); }),
Ok(res) );
} let req = TestRequest::with_uri("/test").to_request();
}), let resp = call_service(&mut srv, req);
) assert_eq!(resp.status(), StatusCode::OK);
.await; assert_eq!(
let req = TestRequest::with_uri("/test").to_request(); resp.headers().get(header::CONTENT_TYPE).unwrap(),
let resp = call_service(&mut srv, req).await; HeaderValue::from_static("0001")
assert_eq!(resp.status(), StatusCode::OK); );
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_external_resource() { fn test_external_resource() {
block_on(async { let mut srv = init_service(
let mut srv = init_service( App::new()
App::new() .external_resource("youtube", "https://youtube.com/watch/{video_id}")
.external_resource("youtube", "https://youtube.com/watch/{video_id}") .route(
.route( "/test",
"/test", web::get().to(|req: HttpRequest| {
web::get().to(|req: HttpRequest| { HttpResponse::Ok().body(format!(
HttpResponse::Ok().body(format!( "{}",
"{}", req.url_for("youtube", &["12345"]).unwrap()
req.url_for("youtube", &["12345"]).unwrap() ))
)) }),
}), ),
), );
) let req = TestRequest::with_uri("/test").to_request();
.await; let resp = call_service(&mut srv, req);
let req = TestRequest::with_uri("/test").to_request(); assert_eq!(resp.status(), StatusCode::OK);
let resp = call_service(&mut srv, req).await; let body = read_body(resp);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
} }
} }

View File

@@ -1,16 +1,14 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::{Extensions, Request, Response}; use actix_http::{Extensions, Request, Response};
use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::boxed::{self, BoxedNewService, BoxedService}; use actix_service::boxed::{self, BoxedNewService, BoxedService};
use actix_service::{service_fn, Service, ServiceFactory}; use actix_service::{service_fn, NewService, Service};
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, FutureResult};
use futures::{Async, Future, Poll};
use crate::config::{AppConfig, AppService}; use crate::config::{AppConfig, AppService};
use crate::data::DataFactory; use crate::data::DataFactory;
@@ -18,20 +16,23 @@ use crate::error::Error;
use crate::guard::Guard; use crate::guard::Guard;
use crate::request::{HttpRequest, HttpRequestPool}; use crate::request::{HttpRequest, HttpRequestPool};
use crate::rmap::ResourceMap; use crate::rmap::ResourceMap;
use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; use crate::service::{ServiceFactory, ServiceRequest, ServiceResponse};
type Guards = Vec<Box<dyn Guard>>; type Guards = Vec<Box<dyn Guard>>;
type HttpService = BoxedService<ServiceRequest, ServiceResponse, Error>; type HttpService = BoxedService<ServiceRequest, ServiceResponse, Error>;
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type BoxedResponse = LocalBoxFuture<'static, Result<ServiceResponse, Error>>; type BoxedResponse = Either<
FutureResult<ServiceResponse, Error>,
Box<dyn Future<Item = ServiceResponse, Error = Error>>,
>;
type FnDataFactory = type FnDataFactory =
Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>; Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>;
/// Service factory to convert `Request` to a `ServiceRequest<S>`. /// Service factory to convert `Request` to a `ServiceRequest<S>`.
/// It also executes data factories. /// It also executes data factories.
pub struct AppInit<T, B> pub struct AppInit<T, B>
where where
T: ServiceFactory< T: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -43,15 +44,15 @@ where
pub(crate) data: Rc<Vec<Box<dyn DataFactory>>>, pub(crate) data: Rc<Vec<Box<dyn DataFactory>>>,
pub(crate) data_factories: Rc<Vec<FnDataFactory>>, pub(crate) data_factories: Rc<Vec<FnDataFactory>>,
pub(crate) config: RefCell<AppConfig>, pub(crate) config: RefCell<AppConfig>,
pub(crate) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory>>>>, pub(crate) services: Rc<RefCell<Vec<Box<dyn ServiceFactory>>>>,
pub(crate) default: Option<Rc<HttpNewService>>, pub(crate) default: Option<Rc<HttpNewService>>,
pub(crate) factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>, pub(crate) factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
pub(crate) external: RefCell<Vec<ResourceDef>>, pub(crate) external: RefCell<Vec<ResourceDef>>,
} }
impl<T, B> ServiceFactory for AppInit<T, B> impl<T, B> NewService for AppInit<T, B>
where where
T: ServiceFactory< T: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -70,8 +71,8 @@ where
fn new_service(&self, cfg: &ServerConfig) -> Self::Future { fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
// update resource default service // update resource default service
let default = self.default.clone().unwrap_or_else(|| { let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::factory(service_fn(|req: ServiceRequest| { Rc::new(boxed::new_service(service_fn(|req: ServiceRequest| {
ok(req.into_response(Response::NotFound().finish())) Ok(req.into_response(Response::NotFound().finish()))
}))) })))
}); });
@@ -134,25 +135,23 @@ where
} }
} }
#[pin_project::pin_project]
pub struct AppInitResult<T, B> pub struct AppInitResult<T, B>
where where
T: ServiceFactory, T: NewService,
{ {
endpoint: Option<T::Service>, endpoint: Option<T::Service>,
#[pin]
endpoint_fut: T::Future, endpoint_fut: T::Future,
rmap: Rc<ResourceMap>, rmap: Rc<ResourceMap>,
config: AppConfig, config: AppConfig,
data: Rc<Vec<Box<dyn DataFactory>>>, data: Rc<Vec<Box<dyn DataFactory>>>,
data_factories: Vec<Box<dyn DataFactory>>, data_factories: Vec<Box<dyn DataFactory>>,
data_factories_fut: Vec<LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>, data_factories_fut: Vec<Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>,
_t: PhantomData<B>, _t: PhantomData<B>,
} }
impl<T, B> Future for AppInitResult<T, B> impl<T, B> Future for AppInitResult<T, B>
where where
T: ServiceFactory< T: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -160,49 +159,48 @@ where
InitError = (), InitError = (),
>, >,
{ {
type Output = Result<AppInitService<T::Service, B>, ()>; type Item = AppInitService<T::Service, B>;
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
// async data factories // async data factories
let mut idx = 0; let mut idx = 0;
while idx < this.data_factories_fut.len() { while idx < self.data_factories_fut.len() {
match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? { match self.data_factories_fut[idx].poll()? {
Poll::Ready(f) => { Async::Ready(f) => {
this.data_factories.push(f); self.data_factories.push(f);
let _ = this.data_factories_fut.remove(idx); let _ = self.data_factories_fut.remove(idx);
} }
Poll::Pending => idx += 1, Async::NotReady => idx += 1,
} }
} }
if this.endpoint.is_none() { if self.endpoint.is_none() {
if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? { if let Async::Ready(srv) = self.endpoint_fut.poll()? {
*this.endpoint = Some(srv); self.endpoint = Some(srv);
} }
} }
if this.endpoint.is_some() && this.data_factories_fut.is_empty() { if self.endpoint.is_some() && self.data_factories_fut.is_empty() {
// create app data container // create app data container
let mut data = Extensions::new(); let mut data = Extensions::new();
for f in this.data.iter() { for f in self.data.iter() {
f.create(&mut data); f.create(&mut data);
} }
for f in this.data_factories.iter() { for f in &self.data_factories {
f.create(&mut data); f.create(&mut data);
} }
Poll::Ready(Ok(AppInitService { Ok(Async::Ready(AppInitService {
service: this.endpoint.take().unwrap(), service: self.endpoint.take().unwrap(),
rmap: this.rmap.clone(), rmap: self.rmap.clone(),
config: this.config.clone(), config: self.config.clone(),
data: Rc::new(data), data: Rc::new(data),
pool: HttpRequestPool::create(), pool: HttpRequestPool::create(),
})) }))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -228,8 +226,8 @@ where
type Error = T::Error; type Error = T::Error;
type Future = T::Future; type Future = T::Future;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {
@@ -272,7 +270,7 @@ pub struct AppRoutingFactory {
default: Rc<HttpNewService>, default: Rc<HttpNewService>,
} }
impl ServiceFactory for AppRoutingFactory { impl NewService for AppRoutingFactory {
type Config = (); type Config = ();
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
@@ -290,7 +288,7 @@ impl ServiceFactory for AppRoutingFactory {
CreateAppRoutingItem::Future( CreateAppRoutingItem::Future(
Some(path.clone()), Some(path.clone()),
guards.borrow_mut().take(), guards.borrow_mut().take(),
service.new_service(&()).boxed_local(), service.new_service(&()),
) )
}) })
.collect(), .collect(),
@@ -300,14 +298,14 @@ impl ServiceFactory for AppRoutingFactory {
} }
} }
type HttpServiceFut = LocalBoxFuture<'static, Result<HttpService, ()>>; type HttpServiceFut = Box<dyn Future<Item = HttpService, Error = ()>>;
/// Create app service /// Create app service
#[doc(hidden)] #[doc(hidden)]
pub struct AppRoutingFactoryResponse { pub struct AppRoutingFactoryResponse {
fut: Vec<CreateAppRoutingItem>, fut: Vec<CreateAppRoutingItem>,
default: Option<HttpService>, default: Option<HttpService>,
default_fut: Option<LocalBoxFuture<'static, Result<HttpService, ()>>>, default_fut: Option<Box<dyn Future<Item = HttpService, Error = ()>>>,
} }
enum CreateAppRoutingItem { enum CreateAppRoutingItem {
@@ -316,15 +314,16 @@ enum CreateAppRoutingItem {
} }
impl Future for AppRoutingFactoryResponse { impl Future for AppRoutingFactoryResponse {
type Output = Result<AppRouting, ()>; type Item = AppRouting;
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut done = true; let mut done = true;
if let Some(ref mut fut) = self.default_fut { if let Some(ref mut fut) = self.default_fut {
match Pin::new(fut).poll(cx)? { match fut.poll()? {
Poll::Ready(default) => self.default = Some(default), Async::Ready(default) => self.default = Some(default),
Poll::Pending => done = false, Async::NotReady => done = false,
} }
} }
@@ -335,12 +334,11 @@ impl Future for AppRoutingFactoryResponse {
ref mut path, ref mut path,
ref mut guards, ref mut guards,
ref mut fut, ref mut fut,
) => match Pin::new(fut).poll(cx) { ) => match fut.poll()? {
Poll::Ready(Ok(service)) => { Async::Ready(service) => {
Some((path.take().unwrap(), guards.take(), service)) Some((path.take().unwrap(), guards.take(), service))
} }
Poll::Ready(Err(_)) => return Poll::Ready(Err(())), Async::NotReady => {
Poll::Pending => {
done = false; done = false;
None None
} }
@@ -366,13 +364,13 @@ impl Future for AppRoutingFactoryResponse {
} }
router router
}); });
Poll::Ready(Ok(AppRouting { Ok(Async::Ready(AppRouting {
ready: None, ready: None,
router: router.finish(), router: router.finish(),
default: self.default.take(), default: self.default.take(),
})) }))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -389,11 +387,11 @@ impl Service for AppRouting {
type Error = Error; type Error = Error;
type Future = BoxedResponse; type Future = BoxedResponse;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
if self.ready.is_none() { if self.ready.is_none() {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
@@ -415,7 +413,7 @@ impl Service for AppRouting {
default.call(req) default.call(req)
} else { } else {
let req = req.into_parts().0; let req = req.into_parts().0;
ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local() Either::A(ok(ServiceResponse::new(req, Response::NotFound().finish())))
} }
} }
} }
@@ -431,7 +429,7 @@ impl AppEntry {
} }
} }
impl ServiceFactory for AppEntry { impl NewService for AppEntry {
type Config = (); type Config = ();
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
@@ -466,16 +464,15 @@ mod tests {
#[test] #[test]
fn drop_data() { fn drop_data() {
let data = Arc::new(AtomicBool::new(false)); let data = Arc::new(AtomicBool::new(false));
test::block_on(async { {
let mut app = test::init_service( let mut app = test::init_service(
App::new() App::new()
.data(DropData(data.clone())) .data(DropData(data.clone()))
.service(web::resource("/test").to(|| HttpResponse::Ok())), .service(web::resource("/test").to(|| HttpResponse::Ok())),
) );
.await;
let req = test::TestRequest::with_uri("/test").to_request(); let req = test::TestRequest::with_uri("/test").to_request();
let _ = app.call(req).await.unwrap(); let _ = test::block_on(app.call(req)).unwrap();
}); }
assert!(data.load(Ordering::Relaxed)); assert!(data.load(Ordering::Relaxed));
} }
} }

View File

@@ -3,7 +3,7 @@ use std::rc::Rc;
use actix_http::Extensions; use actix_http::Extensions;
use actix_router::ResourceDef; use actix_router::ResourceDef;
use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; use actix_service::{boxed, IntoNewService, NewService};
use crate::data::{Data, DataFactory}; use crate::data::{Data, DataFactory};
use crate::error::Error; use crate::error::Error;
@@ -12,7 +12,7 @@ use crate::resource::Resource;
use crate::rmap::ResourceMap; use crate::rmap::ResourceMap;
use crate::route::Route; use crate::route::Route;
use crate::service::{ use crate::service::{
AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse, ServiceResponse,
}; };
@@ -102,11 +102,11 @@ impl AppService {
&mut self, &mut self,
rdef: ResourceDef, rdef: ResourceDef,
guards: Option<Vec<Box<dyn Guard>>>, guards: Option<Vec<Box<dyn Guard>>>,
factory: F, service: F,
nested: Option<Rc<ResourceMap>>, nested: Option<Rc<ResourceMap>>,
) where ) where
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S: ServiceFactory< S: NewService<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
@@ -116,7 +116,7 @@ impl AppService {
{ {
self.services.push(( self.services.push((
rdef, rdef,
boxed::factory(factory.into_factory()), boxed::new_service(service.into_new_service()),
guards, guards,
nested, nested,
)); ));
@@ -174,7 +174,7 @@ impl Default for AppConfigInner {
/// to set of external methods. This could help with /// to set of external methods. This could help with
/// modularization of big application configuration. /// modularization of big application configuration.
pub struct ServiceConfig { pub struct ServiceConfig {
pub(crate) services: Vec<Box<dyn AppServiceFactory>>, pub(crate) services: Vec<Box<dyn ServiceFactory>>,
pub(crate) data: Vec<Box<dyn DataFactory>>, pub(crate) data: Vec<Box<dyn DataFactory>>,
pub(crate) external: Vec<ResourceDef>, pub(crate) external: Vec<ResourceDef>,
} }
@@ -251,19 +251,17 @@ mod tests {
#[test] #[test]
fn test_data() { fn test_data() {
block_on(async { let cfg = |cfg: &mut ServiceConfig| {
let cfg = |cfg: &mut ServiceConfig| { cfg.data(10usize);
cfg.data(10usize); };
};
let mut srv = init_service(App::new().configure(cfg).service( let mut srv =
init_service(App::new().configure(cfg).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)) ));
.await; let req = TestRequest::default().to_request();
let req = TestRequest::default().to_request(); let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.status(), StatusCode::OK);
})
} }
// #[test] // #[test]
@@ -300,57 +298,50 @@ mod tests {
#[test] #[test]
fn test_external_resource() { fn test_external_resource() {
block_on(async { let mut srv = init_service(
let mut srv = init_service( App::new()
App::new() .configure(|cfg| {
.configure(|cfg| { cfg.external_resource(
cfg.external_resource( "youtube",
"youtube", "https://youtube.com/watch/{video_id}",
"https://youtube.com/watch/{video_id}", );
); })
}) .route(
.route( "/test",
"/test", web::get().to(|req: HttpRequest| {
web::get().to(|req: HttpRequest| { HttpResponse::Ok().body(format!(
HttpResponse::Ok().body(format!( "{}",
"{}", req.url_for("youtube", &["12345"]).unwrap()
req.url_for("youtube", &["12345"]).unwrap() ))
)) }),
}), ),
), );
) let req = TestRequest::with_uri("/test").to_request();
.await; let resp = call_service(&mut srv, req);
let req = TestRequest::with_uri("/test").to_request(); assert_eq!(resp.status(), StatusCode::OK);
let resp = call_service(&mut srv, req).await; let body = read_body(resp);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
} }
#[test] #[test]
fn test_service() { fn test_service() {
block_on(async { let mut srv = init_service(App::new().configure(|cfg| {
let mut srv = init_service(App::new().configure(|cfg| { cfg.service(
cfg.service( web::resource("/test").route(web::get().to(|| HttpResponse::Created())),
web::resource("/test") )
.route(web::get().to(|| HttpResponse::Created())), .route("/index.html", web::get().to(|| HttpResponse::Ok()));
) }));
.route("/index.html", web::get().to(|| HttpResponse::Ok()));
}))
.await;
let req = TestRequest::with_uri("/test") let req = TestRequest::with_uri("/test")
.method(Method::GET) .method(Method::GET)
.to_request(); .to_request();
let resp = call_service(&mut srv, req).await; let resp = call_service(&mut srv, req);
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
let req = TestRequest::with_uri("/index.html") let req = TestRequest::with_uri("/index.html")
.method(Method::GET) .method(Method::GET)
.to_request(); .to_request();
let resp = call_service(&mut srv, req).await; let resp = call_service(&mut srv, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
} }

View File

@@ -3,7 +3,6 @@ use std::sync::Arc;
use actix_http::error::{Error, ErrorInternalServerError}; use actix_http::error::{Error, ErrorInternalServerError};
use actix_http::Extensions; use actix_http::Extensions;
use futures::future::{err, ok, Ready};
use crate::dev::Payload; use crate::dev::Payload;
use crate::extract::FromRequest; use crate::extract::FromRequest;
@@ -45,7 +44,7 @@ pub(crate) trait DataFactory {
/// } /// }
/// ///
/// /// Use `Data<T>` extractor to access data in handler. /// /// Use `Data<T>` extractor to access data in handler.
/// async fn index(data: web::Data<Mutex<MyData>>) { /// fn index(data: web::Data<Mutex<MyData>>) {
/// let mut data = data.lock().unwrap(); /// let mut data = data.lock().unwrap();
/// data.counter += 1; /// data.counter += 1;
/// } /// }
@@ -102,19 +101,19 @@ impl<T> Clone for Data<T> {
impl<T: 'static> FromRequest for Data<T> { impl<T: 'static> FromRequest for Data<T> {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Self, Error>>; type Future = Result<Self, Error>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
if let Some(st) = req.get_app_data::<T>() { if let Some(st) = req.get_app_data::<T>() {
ok(st) Ok(st)
} else { } else {
log::debug!( log::debug!(
"Failed to construct App-level Data extractor. \ "Failed to construct App-level Data extractor. \
Request path: {:?}", Request path: {:?}",
req.path() req.path()
); );
err(ErrorInternalServerError( Err(ErrorInternalServerError(
"App data is not configured, to configure use App::data()", "App data is not configured, to configure use App::data()",
)) ))
} }
@@ -143,99 +142,85 @@ mod tests {
#[test] #[test]
fn test_data_extractor() { fn test_data_extractor() {
block_on(async { let mut srv =
let mut srv = init_service(App::new().data(10usize).service( init_service(App::new().data(10usize).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)) ));
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let mut srv = init_service(App::new().data(10u32).service( let mut srv =
init_service(App::new().data(10u32).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)) ));
.await; let req = TestRequest::default().to_request();
let req = TestRequest::default().to_request(); let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_register_data_extractor() { fn test_register_data_extractor() {
block_on(async { let mut srv =
let mut srv = init_service(App::new().register_data(Data::new(10usize)).service(
init_service(App::new().register_data(Data::new(10usize)).service( web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), ));
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let mut srv = let mut srv =
init_service(App::new().register_data(Data::new(10u32)).service( init_service(App::new().register_data(Data::new(10u32)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)) ));
.await; let req = TestRequest::default().to_request();
let req = TestRequest::default().to_request(); let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_route_data_extractor() { fn test_route_data_extractor() {
block_on(async { let mut srv =
let mut srv = init_service(App::new().service( init_service(App::new().service(web::resource("/").data(10usize).route(
web::resource("/").data(10usize).route(web::get().to( web::get().to(|data: web::Data<usize>| {
|data: web::Data<usize>| { let _ = data.clone();
let _ = data.clone(); HttpResponse::Ok()
HttpResponse::Ok() }),
}, )));
)),
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
// different type // different type
let mut srv = init_service( let mut srv = init_service(
App::new().service( App::new().service(
web::resource("/") web::resource("/")
.data(10u32) .data(10u32)
.route(web::get().to(|_: web::Data<usize>| HttpResponse::Ok())), .route(web::get().to(|_: web::Data<usize>| HttpResponse::Ok())),
), ),
) );
.await; let req = TestRequest::default().to_request();
let req = TestRequest::default().to_request(); let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_override_data() { fn test_override_data() {
block_on(async { let mut srv = init_service(App::new().data(1usize).service(
let mut srv = init_service(App::new().data(1usize).service( web::resource("/").data(10usize).route(web::get().to(
web::resource("/").data(10usize).route(web::get().to( |data: web::Data<usize>| {
|data: web::Data<usize>| { assert_eq!(*data, 10);
assert_eq!(*data, 10); let _ = data.clone();
let _ = data.clone(); HttpResponse::Ok()
HttpResponse::Ok() },
}, )),
)), ));
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap(); let resp = block_on(srv.call(req)).unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
} }

View File

@@ -1,10 +1,8 @@
//! Request extractors //! Request extractors
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_http::error::Error; use actix_http::error::Error;
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::ok;
use futures::{future, Async, Future, IntoFuture, Poll};
use crate::dev::Payload; use crate::dev::Payload;
use crate::request::HttpRequest; use crate::request::HttpRequest;
@@ -17,7 +15,7 @@ pub trait FromRequest: Sized {
type Error: Into<Error>; type Error: Into<Error>;
/// Future that resolves to a Self /// Future that resolves to a Self
type Future: Future<Output = Result<Self, Self::Error>>; type Future: IntoFuture<Item = Self, Error = Self::Error>;
/// Configuration for this extractor /// Configuration for this extractor
type Config: Default + 'static; type Config: Default + 'static;
@@ -50,7 +48,6 @@ pub trait FromRequest: Sized {
/// ```rust /// ```rust
/// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; /// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest};
/// use actix_web::error::ErrorBadRequest; /// use actix_web::error::ErrorBadRequest;
/// use futures::future::{ok, err, Ready};
/// use serde_derive::Deserialize; /// use serde_derive::Deserialize;
/// use rand; /// use rand;
/// ///
@@ -61,21 +58,21 @@ pub trait FromRequest: Sized {
/// ///
/// impl FromRequest for Thing { /// impl FromRequest for Thing {
/// type Error = Error; /// type Error = Error;
/// type Future = Ready<Result<Self, Self::Error>>; /// type Future = Result<Self, Self::Error>;
/// type Config = (); /// type Config = ();
/// ///
/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
/// if rand::random() { /// if rand::random() {
/// ok(Thing { name: "thingy".into() }) /// Ok(Thing { name: "thingy".into() })
/// } else { /// } else {
/// err(ErrorBadRequest("no luck")) /// Err(ErrorBadRequest("no luck"))
/// } /// }
/// ///
/// } /// }
/// } /// }
/// ///
/// /// extract `Thing` from request /// /// extract `Thing` from request
/// async fn index(supplied_thing: Option<Thing>) -> String { /// fn index(supplied_thing: Option<Thing>) -> String {
/// match supplied_thing { /// match supplied_thing {
/// // Puns not intended /// // Puns not intended
/// Some(thing) => format!("Got something: {:?}", thing), /// Some(thing) => format!("Got something: {:?}", thing),
@@ -97,19 +94,21 @@ where
{ {
type Config = T::Config; type Config = T::Config;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Option<T>, Error>>; type Future = Box<dyn Future<Item = Option<T>, Error = Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload) Box::new(
.then(|r| match r { T::from_request(req, payload)
Ok(v) => ok(Some(v)), .into_future()
Err(e) => { .then(|r| match r {
log::debug!("Error for Option<T> extractor: {}", e.into()); Ok(v) => future::ok(Some(v)),
ok(None) Err(e) => {
} log::debug!("Error for Option<T> extractor: {}", e.into());
}) future::ok(None)
.boxed_local() }
}),
)
} }
} }
@@ -122,7 +121,6 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; /// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest};
/// use actix_web::error::ErrorBadRequest; /// use actix_web::error::ErrorBadRequest;
/// use futures::future::{ok, err, Ready};
/// use serde_derive::Deserialize; /// use serde_derive::Deserialize;
/// use rand; /// use rand;
/// ///
@@ -133,20 +131,20 @@ where
/// ///
/// impl FromRequest for Thing { /// impl FromRequest for Thing {
/// type Error = Error; /// type Error = Error;
/// type Future = Ready<Result<Thing, Error>>; /// type Future = Result<Thing, Error>;
/// type Config = (); /// type Config = ();
/// ///
/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
/// if rand::random() { /// if rand::random() {
/// ok(Thing { name: "thingy".into() }) /// Ok(Thing { name: "thingy".into() })
/// } else { /// } else {
/// err(ErrorBadRequest("no luck")) /// Err(ErrorBadRequest("no luck"))
/// } /// }
/// } /// }
/// } /// }
/// ///
/// /// extract `Thing` from request /// /// extract `Thing` from request
/// async fn index(supplied_thing: Result<Thing>) -> String { /// fn index(supplied_thing: Result<Thing>) -> String {
/// match supplied_thing { /// match supplied_thing {
/// Ok(thing) => format!("Got thing: {:?}", thing), /// Ok(thing) => format!("Got thing: {:?}", thing),
/// Err(e) => format!("Error extracting thing: {}", e) /// Err(e) => format!("Error extracting thing: {}", e)
@@ -159,24 +157,26 @@ where
/// ); /// );
/// } /// }
/// ``` /// ```
impl<T> FromRequest for Result<T, T::Error> impl<T: 'static> FromRequest for Result<T, T::Error>
where where
T: FromRequest + 'static, T: FromRequest,
T::Error: 'static,
T::Future: 'static, T::Future: 'static,
T::Error: 'static,
{ {
type Config = T::Config; type Config = T::Config;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Result<T, T::Error>, Error>>; type Future = Box<dyn Future<Item = Result<T, T::Error>, Error = Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
T::from_request(req, payload) Box::new(
.then(|res| match res { T::from_request(req, payload)
Ok(v) => ok(Ok(v)), .into_future()
Err(e) => ok(Err(e)), .then(|res| match res {
}) Ok(v) => ok(Ok(v)),
.boxed_local() Err(e) => ok(Err(e)),
}),
)
} }
} }
@@ -184,10 +184,10 @@ where
impl FromRequest for () { impl FromRequest for () {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<(), Error>>; type Future = Result<(), Error>;
fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(()) Ok(())
} }
} }
@@ -204,44 +204,43 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => {
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
$fut_type { $fut_type {
items: <($(Option<$T>,)+)>::default(), items: <($(Option<$T>,)+)>::default(),
futs: ($($T::from_request(req, payload),)+), futs: ($($T::from_request(req, payload).into_future(),)+),
} }
} }
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project]
pub struct $fut_type<$($T: FromRequest),+> { pub struct $fut_type<$($T: FromRequest),+> {
items: ($(Option<$T>,)+), items: ($(Option<$T>,)+),
futs: ($($T::Future,)+), futs: ($(<$T::Future as futures::IntoFuture>::Future,)+),
} }
impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> impl<$($T: FromRequest),+> Future for $fut_type<$($T),+>
{ {
type Output = Result<($($T,)+), Error>; type Item = ($($T,)+);
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut ready = true; let mut ready = true;
$( $(
if this.items.$n.is_none() { if self.items.$n.is_none() {
match unsafe { Pin::new_unchecked(&mut this.futs.$n) }.poll(cx) { match self.futs.$n.poll() {
Poll::Ready(Ok(item)) => { Ok(Async::Ready(item)) => {
this.items.$n = Some(item); self.items.$n = Some(item);
} }
Poll::Pending => ready = false, Ok(Async::NotReady) => ready = false,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), Err(e) => return Err(e.into()),
} }
} }
)+ )+
if ready { if ready {
Poll::Ready(Ok( Ok(Async::Ready(
($(this.items.$n.take().unwrap(),)+) ($(self.items.$n.take().unwrap(),)+)
)) ))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }

View File

@@ -1,34 +1,28 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_http::{Error, Response}; use actix_http::{Error, Response};
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use futures::ready; use futures::{try_ready, Async, Future, IntoFuture, Poll};
use pin_project::pin_project;
use crate::extract::FromRequest; use crate::extract::FromRequest;
use crate::request::HttpRequest; use crate::request::HttpRequest;
use crate::responder::Responder; use crate::responder::Responder;
use crate::service::{ServiceRequest, ServiceResponse}; use crate::service::{ServiceRequest, ServiceResponse};
/// Async handler converter factory /// Handler converter factory
pub trait Factory<T, R, O>: Clone + 'static pub trait Factory<T, R>: Clone
where where
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
fn call(&self, param: T) -> R; fn call(&self, param: T) -> R;
} }
impl<F, R, O> Factory<(), R, O> for F impl<F, R> Factory<(), R> for F
where where
F: Fn() -> R + Clone + 'static, F: Fn() -> R + Clone,
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
fn call(&self, _: ()) -> R { fn call(&self, _: ()) -> R {
(self)() (self)()
@@ -36,21 +30,19 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
pub struct Handler<F, T, R, O> pub struct Handler<F, T, R>
where where
F: Factory<T, R, O>, F: Factory<T, R>,
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
hnd: F, hnd: F,
_t: PhantomData<(T, R, O)>, _t: PhantomData<(T, R)>,
} }
impl<F, T, R, O> Handler<F, T, R, O> impl<F, T, R> Handler<F, T, R>
where where
F: Factory<T, R, O>, F: Factory<T, R>,
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
pub fn new(hnd: F) -> Self { pub fn new(hnd: F) -> Self {
Handler { Handler {
@@ -60,38 +52,156 @@ where
} }
} }
impl<F, T, R, O> Clone for Handler<F, T, R, O> impl<F, T, R> Clone for Handler<F, T, R>
where where
F: Factory<T, R, O>, F: Factory<T, R>,
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Handler { Self {
hnd: self.hnd.clone(), hnd: self.hnd.clone(),
_t: PhantomData, _t: PhantomData,
} }
} }
} }
impl<F, T, R, O> Service for Handler<F, T, R, O> impl<F, T, R> Service for Handler<F, T, R>
where where
F: Factory<T, R, O>, F: Factory<T, R>,
R: Future<Output = O>, R: Responder,
O: Responder,
{ {
type Request = (T, HttpRequest); type Request = (T, HttpRequest);
type Response = ServiceResponse; type Response = ServiceResponse;
type Error = Infallible; type Error = Infallible;
type Future = HandlerServiceResponse<R, O>; type Future = HandlerServiceResponse<<R::Future as IntoFuture>::Future>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future {
let fut = self.hnd.call(param).respond_to(&req).into_future();
HandlerServiceResponse { HandlerServiceResponse {
fut: self.hnd.call(param), fut,
req: Some(req),
}
}
}
pub struct HandlerServiceResponse<T> {
fut: T,
req: Option<HttpRequest>,
}
impl<T> Future for HandlerServiceResponse<T>
where
T: Future<Item = Response>,
T::Error: Into<Error>,
{
type Item = ServiceResponse;
type Error = Infallible;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.fut.poll() {
Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new(
self.req.take().unwrap(),
res,
))),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
let res: Response = e.into().into();
Ok(Async::Ready(ServiceResponse::new(
self.req.take().unwrap(),
res,
)))
}
}
}
}
/// Async handler converter factory
pub trait AsyncFactory<T, R>: Clone + 'static
where
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
fn call(&self, param: T) -> R;
}
impl<F, R> AsyncFactory<(), R> for F
where
F: Fn() -> R + Clone + 'static,
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
fn call(&self, _: ()) -> R {
(self)()
}
}
#[doc(hidden)]
pub struct AsyncHandler<F, T, R>
where
F: AsyncFactory<T, R>,
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
hnd: F,
_t: PhantomData<(T, R)>,
}
impl<F, T, R> AsyncHandler<F, T, R>
where
F: AsyncFactory<T, R>,
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
pub fn new(hnd: F) -> Self {
AsyncHandler {
hnd,
_t: PhantomData,
}
}
}
impl<F, T, R> Clone for AsyncHandler<F, T, R>
where
F: AsyncFactory<T, R>,
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
fn clone(&self) -> Self {
AsyncHandler {
hnd: self.hnd.clone(),
_t: PhantomData,
}
}
}
impl<F, T, R> Service for AsyncHandler<F, T, R>
where
F: AsyncFactory<T, R>,
R: IntoFuture,
R::Item: Responder,
R::Error: Into<Error>,
{
type Request = (T, HttpRequest);
type Response = ServiceResponse;
type Error = Infallible;
type Future = AsyncHandlerServiceResponse<R::Future>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future {
AsyncHandlerServiceResponse {
fut: self.hnd.call(param).into_future(),
fut2: None, fut2: None,
req: Some(req), req: Some(req),
} }
@@ -99,49 +209,57 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project] pub struct AsyncHandlerServiceResponse<T>
pub struct HandlerServiceResponse<T, R>
where where
T: Future<Output = R>, T: Future,
R: Responder, T::Item: Responder,
{ {
#[pin]
fut: T, fut: T,
#[pin] fut2: Option<<<T::Item as Responder>::Future as IntoFuture>::Future>,
fut2: Option<R::Future>,
req: Option<HttpRequest>, req: Option<HttpRequest>,
} }
impl<T, R> Future for HandlerServiceResponse<T, R> impl<T> Future for AsyncHandlerServiceResponse<T>
where where
T: Future<Output = R>, T: Future,
R: Responder, T::Item: Responder,
T::Error: Into<Error>,
{ {
type Output = Result<ServiceResponse, Infallible>; type Item = ServiceResponse;
type Error = Infallible;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.as_mut().project(); if let Some(ref mut fut) = self.fut2 {
return match fut.poll() {
if let Some(fut) = this.fut2.as_pin_mut() { Ok(Async::Ready(res)) => Ok(Async::Ready(ServiceResponse::new(
return match fut.poll(cx) { self.req.take().unwrap(),
Poll::Ready(Ok(res)) => { res,
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) ))),
} Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Pending => Poll::Pending, Err(e) => {
Poll::Ready(Err(e)) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) Ok(Async::Ready(ServiceResponse::new(
self.req.take().unwrap(),
res,
)))
} }
}; };
} }
match this.fut.poll(cx) { match self.fut.poll() {
Poll::Ready(res) => { Ok(Async::Ready(res)) => {
let fut = res.respond_to(this.req.as_ref().unwrap()); self.fut2 =
self.as_mut().project().fut2.set(Some(fut)); Some(res.respond_to(self.req.as_ref().unwrap()).into_future());
self.poll(cx) self.poll()
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
let res: Response = e.into().into();
Ok(Async::Ready(ServiceResponse::new(
self.req.take().unwrap(),
res,
)))
} }
Poll::Pending => Poll::Pending,
} }
} }
} }
@@ -161,7 +279,7 @@ impl<T: FromRequest, S> Extract<T, S> {
} }
} }
impl<T: FromRequest, S> ServiceFactory for Extract<T, S> impl<T: FromRequest, S> NewService for Extract<T, S>
where where
S: Service< S: Service<
Request = (T, HttpRequest), Request = (T, HttpRequest),
@@ -175,7 +293,7 @@ where
type Error = (Error, ServiceRequest); type Error = (Error, ServiceRequest);
type InitError = (); type InitError = ();
type Service = ExtractService<T, S>; type Service = ExtractService<T, S>;
type Future = Ready<Result<Self::Service, ()>>; type Future = FutureResult<Self::Service, ()>;
fn new_service(&self, _: &()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
ok(ExtractService { ok(ExtractService {
@@ -203,13 +321,13 @@ where
type Error = (Error, ServiceRequest); type Error = (Error, ServiceRequest);
type Future = ExtractResponse<T, S>; type Future = ExtractResponse<T, S>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
let (req, mut payload) = req.into_parts(); let (req, mut payload) = req.into_parts();
let fut = T::from_request(&req, &mut payload); let fut = T::from_request(&req, &mut payload).into_future();
ExtractResponse { ExtractResponse {
fut, fut,
@@ -220,13 +338,10 @@ where
} }
} }
#[pin_project]
pub struct ExtractResponse<T: FromRequest, S: Service> { pub struct ExtractResponse<T: FromRequest, S: Service> {
req: HttpRequest, req: HttpRequest,
service: S, service: S,
#[pin] fut: <T::Future as IntoFuture>::Future,
fut: T::Future,
#[pin]
fut_s: Option<S::Future>, fut_s: Option<S::Future>,
} }
@@ -238,35 +353,40 @@ where
Error = Infallible, Error = Infallible,
>, >,
{ {
type Output = Result<ServiceResponse, (Error, ServiceRequest)>; type Item = ServiceResponse;
type Error = (Error, ServiceRequest);
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.as_mut().project(); if let Some(ref mut fut) = self.fut_s {
return fut.poll().map_err(|_| panic!());
if let Some(fut) = this.fut_s.as_pin_mut() {
return fut.poll(cx).map_err(|_| panic!());
} }
match ready!(this.fut.poll(cx)) { let item = try_ready!(self.fut.poll().map_err(|e| {
Err(e) => { let req = ServiceRequest::new(self.req.clone());
let req = ServiceRequest::new(this.req.clone()); (e.into(), req)
Poll::Ready(Err((e.into(), req))) }));
}
Ok(item) => { self.fut_s = Some(self.service.call((item, self.req.clone())));
let fut = Some(this.service.call((item, this.req.clone()))); self.poll()
self.as_mut().project().fut_s.set(fut);
self.poll(cx)
}
}
} }
} }
/// FromRequest trait impl for tuples /// FromRequest trait impl for tuples
macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => { macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => {
impl<Func, $($T,)+ Res, O> Factory<($($T,)+), Res, O> for Func impl<Func, $($T,)+ Res> Factory<($($T,)+), Res> for Func
where Func: Fn($($T,)+) -> Res + Clone,
Res: Responder,
{
fn call(&self, param: ($($T,)+)) -> Res {
(self)($(param.$n,)+)
}
}
impl<Func, $($T,)+ Res> AsyncFactory<($($T,)+), Res> for Func
where Func: Fn($($T,)+) -> Res + Clone + 'static, where Func: Fn($($T,)+) -> Res + Clone + 'static,
Res: Future<Output = O>, Res: IntoFuture,
O: Responder, Res::Item: Responder,
Res::Error: Into<Error>,
{ {
fn call(&self, param: ($($T,)+)) -> Res { fn call(&self, param: ($($T,)+)) -> Res {
(self)($(param.$n,)+) (self)($(param.$n,)+)

View File

@@ -1,4 +1,4 @@
#![allow(clippy::borrow_interior_mutable_const, unused_imports, dead_code)] #![allow(clippy::borrow_interior_mutable_const)]
//! Actix web is a small, pragmatic, and extremely fast web framework //! Actix web is a small, pragmatic, and extremely fast web framework
//! for Rust. //! for Rust.
//! //!
@@ -6,7 +6,7 @@
//! use actix_web::{web, App, Responder, HttpServer}; //! use actix_web::{web, App, Responder, HttpServer};
//! # use std::thread; //! # use std::thread;
//! //!
//! async fn index(info: web::Path<(String, u32)>) -> impl Responder { //! fn index(info: web::Path<(String, u32)>) -> impl Responder {
//! format!("Hello {}! id:{}", info.0, info.1) //! format!("Hello {}! id:{}", info.0, info.1)
//! } //! }
//! //!
@@ -68,8 +68,8 @@
//! ## Package feature //! ## Package feature
//! //!
//! * `client` - enables http client (default enabled) //! * `client` - enables http client (default enabled)
//! * `openssl` - enables ssl support via `openssl` crate, supports `http/2` //! * `ssl` - enables ssl support via `openssl` crate, supports `http/2`
//! * `rustls` - enables ssl support via `rustls` crate, supports `http/2` //! * `rust-tls` - enables ssl support via `rustls` crate, supports `http/2`
//! * `secure-cookies` - enables secure cookies support, includes `ring` crate as //! * `secure-cookies` - enables secure cookies support, includes `ring` crate as
//! dependency //! dependency
//! * `brotli` - enables `brotli` compression support, requires `c` //! * `brotli` - enables `brotli` compression support, requires `c`
@@ -78,6 +78,7 @@
//! `c` compiler (default enabled) //! `c` compiler (default enabled)
//! * `flate2-rust` - experimental rust based implementation for //! * `flate2-rust` - experimental rust based implementation for
//! `gzip`, `deflate` compression. //! `gzip`, `deflate` compression.
//! * `uds` - Unix domain support, enables `HttpServer::bind_uds()` method.
//! //!
#![allow(clippy::type_complexity, clippy::new_without_default)] #![allow(clippy::type_complexity, clippy::new_without_default)]
@@ -136,16 +137,15 @@ pub mod dev {
pub use crate::config::{AppConfig, AppService}; pub use crate::config::{AppConfig, AppService};
#[doc(hidden)] #[doc(hidden)]
pub use crate::handler::Factory; pub use crate::handler::{AsyncFactory, Factory};
pub use crate::info::ConnectionInfo; pub use crate::info::ConnectionInfo;
pub use crate::rmap::ResourceMap; pub use crate::rmap::ResourceMap;
pub use crate::service::{ pub use crate::service::{
HttpServiceFactory, ServiceRequest, ServiceResponse, WebService, HttpServiceFactory, ServiceRequest, ServiceResponse, WebService,
}; };
pub use crate::types::form::UrlEncoded;
//pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody;
//pub use crate::types::json::JsonBody; pub use crate::types::readlines::Readlines;
//pub use crate::types::readlines::Readlines;
pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream}; pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream};
pub use actix_http::encoding::Decoder as Decompress; pub use actix_http::encoding::Decoder as Decompress;
@@ -171,20 +171,23 @@ pub mod client {
//! An HTTP Client //! An HTTP Client
//! //!
//! ```rust //! ```rust
//! use futures::future::{Future, lazy};
//! use actix_rt::System; //! use actix_rt::System;
//! use actix_web::client::Client; //! use actix_web::client::Client;
//! //!
//! fn main() { //! fn main() {
//! System::new("test").block_on(async { //! System::new("test").block_on(lazy(|| {
//! let mut client = Client::default(); //! let mut client = Client::default();
//! //!
//! // Create request builder and send request //! client.get("http://www.rust-lang.org") // <- Create request builder
//! let response = client.get("http://www.rust-lang.org")
//! .header("User-Agent", "Actix-web") //! .header("User-Agent", "Actix-web")
//! .send().await; // <- Send http request //! .send() // <- Send http request
//! //! .map_err(|_| ())
//! println!("Response: {:?}", response); //! .and_then(|response| { // <- server http response
//! }); //! println!("Response: {:?}", response);
//! Ok(())
//! })
//! }));
//! } //! }
//! ``` //! ```
pub use awc::error::{ pub use awc::error::{

View File

@@ -1,18 +1,15 @@
//! `Middleware` for compressing response body. //! `Middleware` for compressing response body.
use std::cmp; use std::cmp;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::task::{Context, Poll};
use actix_http::body::MessageBody; use actix_http::body::MessageBody;
use actix_http::encoding::Encoder; use actix_http::encoding::Encoder;
use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING}; use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING};
use actix_http::{Error, Response, ResponseBuilder}; use actix_http::{Error, Response, ResponseBuilder};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use pin_project::pin_project; use futures::{Async, Future, Poll};
use crate::service::{ServiceRequest, ServiceResponse}; use crate::service::{ServiceRequest, ServiceResponse};
@@ -81,7 +78,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CompressMiddleware<S>; type Transform = CompressMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CompressMiddleware { ok(CompressMiddleware {
@@ -106,8 +103,8 @@ where
type Error = Error; type Error = Error;
type Future = CompressResponse<S, B>; type Future = CompressResponse<S, B>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -131,13 +128,11 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project]
pub struct CompressResponse<S, B> pub struct CompressResponse<S, B>
where where
S: Service, S: Service,
B: MessageBody, B: MessageBody,
{ {
#[pin]
fut: S::Future, fut: S::Future,
encoding: ContentEncoding, encoding: ContentEncoding,
_t: PhantomData<(B)>, _t: PhantomData<(B)>,
@@ -148,25 +143,21 @@ where
B: MessageBody, B: MessageBody,
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{ {
type Output = Result<ServiceResponse<Encoder<B>>, Error>; type Item = ServiceResponse<Encoder<B>>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.project(); let resp = futures::try_ready!(self.fut.poll());
match futures::ready!(this.fut.poll(cx)) { let enc = if let Some(enc) = resp.response().extensions().get::<Enc>() {
Ok(resp) => { enc.0
let enc = if let Some(enc) = resp.response().extensions().get::<Enc>() { } else {
enc.0 self.encoding
} else { };
*this.encoding
};
Poll::Ready(Ok( Ok(Async::Ready(resp.map_body(move |head, body| {
resp.map_body(move |head, body| Encoder::response(enc, head, body)) Encoder::response(enc, head, body)
)) })))
}
Err(e) => Poll::Ready(Err(e)),
}
} }
} }

View File

@@ -1,8 +1,7 @@
//! `Middleware` for conditionally enables another middleware. //! `Middleware` for conditionally enables another middleware.
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, FutureResult, Map};
use futures::{Future, Poll};
/// `Middleware` for conditionally enables another middleware. /// `Middleware` for conditionally enables another middleware.
/// The controled middleware must not change the `Service` interfaces. /// The controled middleware must not change the `Service` interfaces.
@@ -14,11 +13,11 @@ use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
/// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::middleware::{Condition, NormalizePath};
/// use actix_web::App; /// use actix_web::App;
/// ///
/// # fn main() { /// fn main() {
/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); /// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into());
/// let app = App::new() /// let app = App::new()
/// .wrap(Condition::new(enable_normalize, NormalizePath)); /// .wrap(Condition::new(enable_normalize, NormalizePath));
/// # } /// }
/// ``` /// ```
pub struct Condition<T> { pub struct Condition<T> {
trans: T, trans: T,
@@ -33,31 +32,29 @@ impl<T> Condition<T> {
impl<S, T> Transform<S> for Condition<T> impl<S, T> Transform<S> for Condition<T>
where where
S: Service + 'static, S: Service,
T: Transform<S, Request = S::Request, Response = S::Response, Error = S::Error>, T: Transform<S, Request = S::Request, Response = S::Response, Error = S::Error>,
T::Future: 'static,
T::InitError: 'static,
T::Transform: 'static,
{ {
type Request = S::Request; type Request = S::Request;
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
type InitError = T::InitError; type InitError = T::InitError;
type Transform = ConditionMiddleware<T::Transform, S>; type Transform = ConditionMiddleware<T::Transform, S>;
type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>; type Future = Either<
Map<T::Future, fn(T::Transform) -> Self::Transform>,
FutureResult<Self::Transform, Self::InitError>,
>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
if self.enable { if self.enable {
let f = self.trans.new_transform(service).map(|res| { let f = self
res.map( .trans
ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, .new_transform(service)
) .map(ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform);
}); Either::A(f)
Either::Left(f)
} else { } else {
Either::Right(ok(ConditionMiddleware::Disable(service))) Either::B(ok(ConditionMiddleware::Disable(service)))
} }
.boxed_local()
} }
} }
@@ -76,19 +73,19 @@ where
type Error = E::Error; type Error = E::Error;
type Future = Either<E::Future, D::Future>; type Future = Either<E::Future, D::Future>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
use ConditionMiddleware::*; use ConditionMiddleware::*;
match self { match self {
Enable(service) => service.poll_ready(cx), Enable(service) => service.poll_ready(),
Disable(service) => service.poll_ready(cx), Disable(service) => service.poll_ready(),
} }
} }
fn call(&mut self, req: E::Request) -> Self::Future { fn call(&mut self, req: E::Request) -> Self::Future {
use ConditionMiddleware::*; use ConditionMiddleware::*;
match self { match self {
Enable(service) => Either::Left(service.call(req)), Enable(service) => Either::A(service.call(req)),
Disable(service) => Either::Right(service.call(req)), Disable(service) => Either::B(service.call(req)),
} }
} }
} }
@@ -102,7 +99,7 @@ mod tests {
use crate::error::Result; use crate::error::Result;
use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
use crate::middleware::errhandlers::*; use crate::middleware::errhandlers::*;
use crate::test::{self, block_on, TestRequest}; use crate::test::{self, TestRequest};
use crate::HttpResponse; use crate::HttpResponse;
fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> { fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
@@ -114,44 +111,33 @@ mod tests {
#[test] #[test]
fn test_handler_enabled() { fn test_handler_enabled() {
block_on(async { let srv = |req: ServiceRequest| {
let srv = |req: ServiceRequest| { req.into_response(HttpResponse::InternalServerError().finish())
ok(req.into_response(HttpResponse::InternalServerError().finish())) };
};
let mw = ErrorHandlers::new() let mw =
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mut mw = Condition::new(true, mw) let mut mw =
.new_transform(srv.into_service()) test::block_on(Condition::new(true, mw).new_transform(srv.into_service()))
.await
.unwrap(); .unwrap();
let resp = let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
test::call_service(&mut mw, TestRequest::default().to_srv_request()) assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
.await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
})
} }
#[test] #[test]
fn test_handler_disabled() { fn test_handler_disabled() {
block_on(async { let srv = |req: ServiceRequest| {
let srv = |req: ServiceRequest| { req.into_response(HttpResponse::InternalServerError().finish())
ok(req.into_response(HttpResponse::InternalServerError().finish())) };
};
let mw = ErrorHandlers::new() let mw =
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mut mw = Condition::new(false, mw) let mut mw =
.new_transform(srv.into_service()) test::block_on(Condition::new(false, mw).new_transform(srv.into_service()))
.await
.unwrap(); .unwrap();
let resp = let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
test::call_service(&mut mw, TestRequest::default().to_srv_request()) assert_eq!(resp.headers().get(CONTENT_TYPE), None);
.await;
assert_eq!(resp.headers().get(CONTENT_TYPE), None);
})
} }
} }

View File

@@ -1,11 +1,9 @@
//! Middleware for setting default response headers //! Middleware for setting default response headers
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, FutureResult};
use futures::{Future, Poll};
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use crate::http::{HeaderMap, HttpTryFrom}; use crate::http::{HeaderMap, HttpTryFrom};
@@ -98,7 +96,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = DefaultHeadersMiddleware<S>; type Transform = DefaultHeadersMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(DefaultHeadersMiddleware { ok(DefaultHeadersMiddleware {
@@ -121,19 +119,16 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
let fut = self.service.call(req);
async move {
let mut res = fut.await?;
Box::new(self.service.call(req).map(move |mut res| {
// set response headers // set response headers
for (key, value) in inner.headers.iter() { for (key, value) in inner.headers.iter() {
if !res.headers().contains_key(key) { if !res.headers().contains_key(key) {
@@ -147,16 +142,15 @@ where
HeaderValue::from_static("application/octet-stream"), HeaderValue::from_static("application/octet-stream"),
); );
} }
Ok(res)
} res
.boxed_local() }))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::IntoService; use actix_service::IntoService;
use futures::future::ok;
use super::*; use super::*;
use crate::dev::ServiceRequest; use crate::dev::ServiceRequest;
@@ -166,50 +160,46 @@ mod tests {
#[test] #[test]
fn test_default_headers() { fn test_default_headers() {
block_on(async { let mut mw = block_on(
let mut mw = DefaultHeaders::new() DefaultHeaders::new()
.header(CONTENT_TYPE, "0001") .header(CONTENT_TYPE, "0001")
.new_transform(ok_service()) .new_transform(ok_service()),
.await )
.unwrap(); .unwrap();
let req = TestRequest::default().to_srv_request(); let req = TestRequest::default().to_srv_request();
let resp = mw.call(req).await.unwrap(); let resp = block_on(mw.call(req)).unwrap();
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let req = TestRequest::default().to_srv_request(); let req = TestRequest::default().to_srv_request();
let srv = |req: ServiceRequest| { let srv = |req: ServiceRequest| {
ok(req.into_response( req.into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())
HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish(), };
)) let mut mw = block_on(
}; DefaultHeaders::new()
let mut mw = DefaultHeaders::new()
.header(CONTENT_TYPE, "0001") .header(CONTENT_TYPE, "0001")
.new_transform(srv.into_service()) .new_transform(srv.into_service()),
.await )
.unwrap(); .unwrap();
let resp = mw.call(req).await.unwrap(); let resp = block_on(mw.call(req)).unwrap();
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
})
} }
#[test] #[test]
fn test_content_type() { fn test_content_type() {
block_on(async { let srv = |req: ServiceRequest| req.into_response(HttpResponse::Ok().finish());
let srv = let mut mw = block_on(
|req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); DefaultHeaders::new()
let mut mw = DefaultHeaders::new()
.content_type() .content_type()
.new_transform(srv.into_service()) .new_transform(srv.into_service()),
.await )
.unwrap(); .unwrap();
let req = TestRequest::default().to_srv_request(); let req = TestRequest::default().to_srv_request();
let resp = mw.call(req).await.unwrap(); let resp = block_on(mw.call(req)).unwrap();
assert_eq!( assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(), resp.headers().get(CONTENT_TYPE).unwrap(),
"application/octet-stream" "application/octet-stream"
); );
})
} }
} }

View File

@@ -1,9 +1,9 @@
//! Custom handlers service for responses. //! Custom handlers service for responses.
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{err, ok, Either, Future, FutureExt, LocalBoxFuture, Ready}; use futures::future::{err, ok, Either, Future, FutureResult};
use futures::Poll;
use hashbrown::HashMap; use hashbrown::HashMap;
use crate::dev::{ServiceRequest, ServiceResponse}; use crate::dev::{ServiceRequest, ServiceResponse};
@@ -15,7 +15,7 @@ pub enum ErrorHandlerResponse<B> {
/// New http response got generated /// New http response got generated
Response(ServiceResponse<B>), Response(ServiceResponse<B>),
/// Result is a future that resolves to a new http response /// Result is a future that resolves to a new http response
Future(LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>), Future(Box<dyn Future<Item = ServiceResponse<B>, Error = Error>>),
} }
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>; type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
@@ -39,17 +39,17 @@ type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse
/// Ok(ErrorHandlerResponse::Response(res)) /// Ok(ErrorHandlerResponse::Response(res))
/// } /// }
/// ///
/// # fn main() { /// fn main() {
/// let app = App::new() /// let app = App::new()
/// .wrap( /// .wrap(
/// ErrorHandlers::new() /// ErrorHandlers::new()
/// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500), /// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500),
/// ) /// )
/// .service(web::resource("/test") /// .service(web::resource("/test")
/// .route(web::get().to(|| HttpResponse::Ok())) /// .route(web::get().to(|| HttpResponse::Ok()))
/// .route(web::head().to(|| HttpResponse::MethodNotAllowed()) /// .route(web::head().to(|| HttpResponse::MethodNotAllowed())
/// )); /// ));
/// # } /// }
/// ``` /// ```
pub struct ErrorHandlers<B> { pub struct ErrorHandlers<B> {
handlers: Rc<HashMap<StatusCode, Box<ErrorHandler<B>>>>, handlers: Rc<HashMap<StatusCode, Box<ErrorHandler<B>>>>,
@@ -92,7 +92,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = ErrorHandlersMiddleware<S, B>; type Transform = ErrorHandlersMiddleware<S, B>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(ErrorHandlersMiddleware { ok(ErrorHandlersMiddleware {
@@ -117,30 +117,26 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
let handlers = self.handlers.clone(); let handlers = self.handlers.clone();
let fut = self.service.call(req);
async move {
let res = fut.await?;
Box::new(self.service.call(req).and_then(move |res| {
if let Some(handler) = handlers.get(&res.status()) { if let Some(handler) = handlers.get(&res.status()) {
match handler(res) { match handler(res) {
Ok(ErrorHandlerResponse::Response(res)) => Ok(res), Ok(ErrorHandlerResponse::Response(res)) => Either::A(ok(res)),
Ok(ErrorHandlerResponse::Future(fut)) => fut.await, Ok(ErrorHandlerResponse::Future(fut)) => Either::B(fut),
Err(e) => Err(e), Err(e) => Either::A(err(e)),
} }
} else { } else {
Ok(res) Either::A(ok(res))
} }
} }))
.boxed_local()
} }
} }
@@ -151,7 +147,7 @@ mod tests {
use super::*; use super::*;
use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
use crate::test::{self, block_on, TestRequest}; use crate::test::{self, TestRequest};
use crate::HttpResponse; use crate::HttpResponse;
fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> { fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
@@ -163,22 +159,19 @@ mod tests {
#[test] #[test]
fn test_handler() { fn test_handler() {
block_on(async { let srv = |req: ServiceRequest| {
let srv = |req: ServiceRequest| { req.into_response(HttpResponse::InternalServerError().finish())
ok(req.into_response(HttpResponse::InternalServerError().finish())) };
};
let mut mw = ErrorHandlers::new() let mut mw = test::block_on(
ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500)
.new_transform(srv.into_service()) .new_transform(srv.into_service()),
.await )
.unwrap(); .unwrap();
let resp = let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
test::call_service(&mut mw, TestRequest::default().to_srv_request()) assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
.await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
})
} }
fn render_500_async<B: 'static>( fn render_500_async<B: 'static>(
@@ -187,26 +180,23 @@ mod tests {
res.response_mut() res.response_mut()
.headers_mut() .headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) Ok(ErrorHandlerResponse::Future(Box::new(ok(res))))
} }
#[test] #[test]
fn test_handler_async() { fn test_handler_async() {
block_on(async { let srv = |req: ServiceRequest| {
let srv = |req: ServiceRequest| { req.into_response(HttpResponse::InternalServerError().finish())
ok(req.into_response(HttpResponse::InternalServerError().finish())) };
};
let mut mw = ErrorHandlers::new() let mut mw = test::block_on(
ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async)
.new_transform(srv.into_service()) .new_transform(srv.into_service()),
.await )
.unwrap(); .unwrap();
let resp = let resp = test::call_service(&mut mw, TestRequest::default().to_srv_request());
test::call_service(&mut mw, TestRequest::default().to_srv_request()) assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
.await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
})
} }
} }

View File

@@ -2,15 +2,13 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::env; use std::env;
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{ok, Ready}; use futures::future::{ok, FutureResult};
use futures::{Async, Future, Poll};
use log::debug; use log::debug;
use regex::Regex; use regex::Regex;
use time; use time;
@@ -127,7 +125,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = LoggerMiddleware<S>; type Transform = LoggerMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(LoggerMiddleware { ok(LoggerMiddleware {
@@ -153,8 +151,8 @@ where
type Error = Error; type Error = Error;
type Future = LoggerResponse<S, B>; type Future = LoggerResponse<S, B>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -183,13 +181,11 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project]
pub struct LoggerResponse<S, B> pub struct LoggerResponse<S, B>
where where
B: MessageBody, B: MessageBody,
S: Service, S: Service,
{ {
#[pin]
fut: S::Future, fut: S::Future,
time: time::Tm, time: time::Tm,
format: Option<Format>, format: Option<Format>,
@@ -201,15 +197,11 @@ where
B: MessageBody, B: MessageBody,
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
{ {
type Output = Result<ServiceResponse<StreamLog<B>>, Error>; type Item = ServiceResponse<StreamLog<B>>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.project(); let res = futures::try_ready!(self.fut.poll());
let res = match futures::ready!(this.fut.poll(cx)) {
Ok(res) => res,
Err(e) => return Poll::Ready(Err(e)),
};
if let Some(error) = res.response().error() { if let Some(error) = res.response().error() {
if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR {
@@ -217,21 +209,18 @@ where
} }
} }
if let Some(ref mut format) = this.format { if let Some(ref mut format) = self.format {
for unit in &mut format.0 { for unit in &mut format.0 {
unit.render_response(res.response()); unit.render_response(res.response());
} }
} }
let time = *this.time; Ok(Async::Ready(res.map_body(move |_, body| {
let format = this.format.take();
Poll::Ready(Ok(res.map_body(move |_, body| {
ResponseBody::Body(StreamLog { ResponseBody::Body(StreamLog {
body, body,
time,
format,
size: 0, size: 0,
time: self.time,
format: self.format.take(),
}) })
}))) })))
} }
@@ -263,13 +252,13 @@ impl<B: MessageBody> MessageBody for StreamLog<B> {
self.body.size() self.body.size()
} }
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
match self.body.poll_next(cx) { match self.body.poll_next()? {
Poll::Ready(Some(Ok(chunk))) => { Async::Ready(Some(chunk)) => {
self.size += chunk.len(); self.size += chunk.len();
Poll::Ready(Some(Ok(chunk))) Ok(Async::Ready(Some(chunk)))
} }
val => val, val => Ok(val),
} }
} }
} }
@@ -475,7 +464,6 @@ impl<'a> fmt::Display for FormatDisplay<'a> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{IntoService, Service, Transform}; use actix_service::{IntoService, Service, Transform};
use futures::future::ok;
use super::*; use super::*;
use crate::http::{header, StatusCode}; use crate::http::{header, StatusCode};
@@ -484,11 +472,11 @@ mod tests {
#[test] #[test]
fn test_logger() { fn test_logger() {
let srv = |req: ServiceRequest| { let srv = |req: ServiceRequest| {
ok(req.into_response( req.into_response(
HttpResponse::build(StatusCode::OK) HttpResponse::build(StatusCode::OK)
.header("X-Test", "ttt") .header("X-Test", "ttt")
.finish(), .finish(),
)) )
}; };
let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test");

Some files were not shown because too many files have changed in this diff Show More